Skip to content

Commit

Permalink
Add memoize.{update|upsert}
Browse files Browse the repository at this point in the history
Resolve #56
  • Loading branch information
Cameron Evans committed Mar 26, 2020
1 parent f8f4a83 commit 393d957
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 93 deletions.
227 changes: 153 additions & 74 deletions atools/_memoize_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from textwrap import dedent
from time import time
from threading import Lock as SyncLock
from typing import Any, Callable, Hashable, Mapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Iterable, Hashable, Mapping, Optional, Tuple, Type, Union
from weakref import finalize, WeakSet


Expand All @@ -32,7 +32,6 @@ class _MemoReturnState:

@dataclass(frozen=True)
class _MemoBase:
fn: Callable
t0: Optional[float]
memo_return_state: _MemoReturnState = field(init=False, default_factory=_MemoReturnState)

Expand Down Expand Up @@ -90,7 +89,7 @@ def __post_init__(self) -> None:
for k, t0, t, v in self.db.execute(
f"SELECT k, t0, t, v FROM `{self.table_name}` ORDER BY t"
).fetchall():
memo = self.make_memo(fn=self.fn, t0=t0)
memo = self.make_memo(t0=t0)
memo.memo_return_state.called = True
memo.memo_return_state.value = pickle.loads(v)
self.memos[k] = memo
Expand All @@ -114,7 +113,7 @@ def table_name(self) -> str:

def bind_key_lifetime(self, raw_key: Tuple[Any, ...], key: Union[int, str]) -> None:
for raw_key_part in raw_key:
if (raw_key_part is not None) and (type(raw_key_part).__hash__ is object.__hash__):
if type(raw_key_part).__hash__ is object.__hash__:
finalize(raw_key_part, self.reset_key, key)

def default_keygen(self, *args, **kwargs) -> Tuple[Hashable, ...]:
Expand All @@ -128,21 +127,23 @@ def get_args_as_kwargs(self, *args, **kwargs) -> Mapping[str, Any]:
args_as_kwargs[k] = v
return ChainMap(args_as_kwargs, kwargs, self.default_kwargs)

def get_memo(self, key: Union[int, str]) -> _Memo:
def get_memo(self, key: Union[int, str], insert: bool) -> Optional[_Memo]:
try:
memo = self.memos[key] = self.memos.pop(key)
if self.duration is not None and memo.t0 < time() - self.duration.total_seconds():
self.expire_order.pop(key)
raise ValueError('value expired')
except (KeyError, ValueError):
if self.duration is None:
if not insert:
return None
elif self.duration is None:
t0 = None
else:
t0 = time()
# The value has no significance. We're using the dict entirely for ordering keys.
self.expire_order[key] = ...

memo = self.memos[key] = self.make_memo(self.fn, t0=t0)
memo = self.memos[key] = self.make_memo(t0=t0)

return memo

Expand All @@ -166,24 +167,23 @@ def expire_one_memo(self) -> None:
def finalize_memo(self, memo: _Memo, key: Union[int, str]) -> Any:
if memo.memo_return_state.raised:
raise memo.memo_return_state.value
else:
if (self.db is not None) and (self.memos[key] is memo):
value = pickle.dumps(memo.memo_return_state.value)
self.db.execute(
dedent(f'''
INSERT OR REPLACE INTO `{self.table_name}`
(k, t0, t, v)
VALUES
(?, ?, ?, ?)
'''),
(
key,
memo.t0,
time(),
value
)
elif (self.db is not None) and (self.memos[key] is memo):
value = pickle.dumps(memo.memo_return_state.value)
self.db.execute(
dedent(f'''
INSERT OR REPLACE INTO `{self.table_name}`
(k, t0, t, v)
VALUES
(?, ?, ?, ?)
'''),
(
key,
memo.t0,
time(),
value
)
return memo.memo_return_state.value
)
return memo.memo_return_state.value

def get_key(self, raw_key: Tuple[Hashable, ...]) -> Union[int, str]:
if self.db is None:
Expand All @@ -194,9 +194,9 @@ def get_key(self, raw_key: Tuple[Hashable, ...]) -> Union[int, str]:
return key

@staticmethod
def make_memo(fn, t0: Optional[float]) -> _Memo: # pragma: no cover
def make_memo(t0: Optional[float]) -> _Memo: # pragma: no cover
raise NotImplemented

def reset(self) -> None:
object.__setattr__(self, 'expire_order', OrderedDict())
object.__setattr__(self, 'memos', OrderedDict())
Expand Down Expand Up @@ -232,47 +232,86 @@ async def get_raw_key(self, *args, **kwargs) -> Tuple[Hashable, ...]:

return raw_key

def get_decorator(self) -> Callable:
async def decorator(*args, **kwargs) -> Any:
raw_key = await self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)
def get_behavior(self, *, insert: bool, update: bool) -> Callable:
def get_call(*, fn: Callable) -> Callable:

memo: _AsyncMemo = self.get_memo(key)
@wraps(self.fn)
async def call(*args, **kwargs) -> Any:
raw_key = await self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)

self.expire_one_memo()
memo: _AsyncMemo = self.get_memo(key, insert=insert)
if memo is None:
return await fn(*args, **kwargs)

async with memo.async_lock:
if not memo.memo_return_state.called:
memo.memo_return_state.called = True
try:
memo.memo_return_state.value = await memo.fn(*args, **kwargs)
except Exception as e:
memo.memo_return_state.raised = True
memo.memo_return_state.value = e
self.expire_one_memo()

self.bind_key_lifetime(raw_key, key)
async with memo.async_lock:
if (
(insert and not memo.memo_return_state.called) or
(update and memo.memo_return_state.value is not _MemoZeroValue)
):
memo.memo_return_state.called = True
try:
memo.memo_return_state.value = await fn(*args, **kwargs)
except Exception as e:
memo.memo_return_state.raised = True
memo.memo_return_state.value = e

return self.finalize_memo(memo=memo, key=key)
self.bind_key_lifetime(raw_key, key)

decorator.memoize = self
return self.finalize_memo(memo=memo, key=key)

return decorator
return call
return get_call

@staticmethod
def make_memo(fn, t0: Optional[float]) -> _AsyncMemo:
return _AsyncMemo(fn=fn, t0=t0)
async def insert(self, *args, **kwargs) -> Any:
return await self.get_behavior(insert=True, update=False)(fn=self.fn)(*args, **kwargs)

def update(self, *args, **kwargs) -> Callable:

async def to(value: Any) -> Any:
async def fn(*_args, **_kwargs) -> Any:
return value

return await self.get_behavior(insert=False, update=True)(fn=fn)(*args, **kwargs)

return to

def upsert(self, *args, **kwargs) -> Callable:

async def reset_call(self, *args, **kwargs) -> None:
async def to(value: Any) -> Any:
async def fn(*_args, **_kwargs) -> Any:
return value

return await self.get_behavior(insert=True, update=True)(fn=fn)(*args, **kwargs)

return to

async def remove(self, *args, **kwargs) -> None:
raw_key = await self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)
self.reset_key(key)

def get_decorator(self) -> Callable:

async def decorator(*args, **kwargs) -> Any:
return await self.insert(*args, **kwargs)

decorator.memoize = self

return decorator

@staticmethod
def make_memo(t0: Optional[float]) -> _AsyncMemo:
return _AsyncMemo(t0=t0)


@dataclass(frozen=True)
class _SyncMemoize(_MemoizeBase):

_sync_lock: SyncLock = field(init=False, default_factory=lambda: SyncLock())

def get_raw_key(self, *args, **kwargs) -> Tuple[Hashable, ...]:
if self.keygen is None:
raw_key = self.default_keygen(*args, **kwargs)
Expand All @@ -284,46 +323,86 @@ def get_raw_key(self, *args, **kwargs) -> Tuple[Hashable, ...]:

return raw_key

def get_decorator(self) -> Callable:
def decorator(*args, **kwargs):
raw_key = self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)
def get_behavior(self, *, insert: bool, update: bool) -> Callable:
def get_call(*, fn: Callable) -> Callable:

@wraps(self.fn)
def call(*args, **kwargs) -> Any:
raw_key = self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)

with self._sync_lock:
memo: _SyncMemo = self.get_memo(key, insert=insert)
if memo is None:
return fn(*args, **kwargs)

self.expire_one_memo()

with memo.sync_lock:
if (
(insert and not memo.memo_return_state.called) or
(update and memo.memo_return_state.value is not _MemoZeroValue)
):
memo.memo_return_state.called = True
try:
memo.memo_return_state.value = fn(*args, **kwargs)
except Exception as e:
memo.memo_return_state.raised = True
memo.memo_return_state.value = e

self.bind_key_lifetime(raw_key, key)

return self.finalize_memo(memo=memo, key=key)

return call

return get_call

def insert(self, *args, **kwargs) -> Any:
return self.get_behavior(insert=True, update=False)(fn=self.fn)(*args, **kwargs)

def update(self, *args, **kwargs) -> Callable:

def to(value: Any) -> Any:
def fn(*_args, **_kwargs) -> Any:
return value

with self._sync_lock:
memo: _SyncMemo = self.get_memo(key)
return self.get_behavior(insert=False, update=True)(fn=fn)(*args, **kwargs)

self.expire_one_memo()
return to

with memo.sync_lock:
if not memo.memo_return_state.called:
memo.memo_return_state.called = True
try:
memo.memo_return_state.value = memo.fn(*args, **kwargs)
except Exception as e:
memo.memo_return_state.raised = True
memo.memo_return_state.value = e
def upsert(self, *args, **kwargs) -> Callable:

self.bind_key_lifetime(raw_key, key)
def to(value: Any) -> Any:
def fn(*_args, **_kwargs) -> Any:
return value

return self.finalize_memo(memo=memo, key=key)
return self.get_behavior(insert=True, update=True)(fn=fn)(*args, **kwargs)

return to

def remove(self, *args, **kwargs) -> None:
raw_key = self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)
self.reset_key(key)

def get_decorator(self) -> Callable:

def decorator(*args, **kwargs) -> Any:
return self.insert(*args, **kwargs)

decorator.memoize = self

return decorator

@staticmethod
def make_memo(fn, t0: Optional[float]) -> _SyncMemo:
return _SyncMemo(fn=fn, t0=t0)
def make_memo(t0: Optional[float]) -> _SyncMemo:
return _SyncMemo(t0=t0)

def reset(self) -> None:
with self._sync_lock:
super().reset()

def reset_call(self, *args, **kwargs) -> None:
raw_key = self.get_raw_key(*args, **kwargs)
key = self.get_key(raw_key)
self.reset_key(key)

def reset_key(self, key: Union[int, str]) -> None:
with self._sync_lock:
super().reset_key(key)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='atools',
version='0.12.1',
version='0.13.0',
packages=find_packages(),
python_requires='>=3.6',
url='https://github.com/cevans87/atools',
Expand Down
Loading

0 comments on commit 393d957

Please sign in to comment.