Skip to content

Commit

Permalink
enable more robust multiple dispatch with plum
Browse files Browse the repository at this point in the history
  • Loading branch information
seeM committed Jul 1, 2022
1 parent 06922b7 commit 6df0281
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 1,038 deletions.
6 changes: 2 additions & 4 deletions fastcore/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,8 @@
"do_request": "03b_net.ipynb",
"start_server": "03b_net.ipynb",
"start_client": "03b_net.ipynb",
"lenient_issubclass": "04_dispatch.ipynb",
"sorted_topologically": "04_dispatch.ipynb",
"TypeDispatch": "04_dispatch.ipynb",
"DispatchReg": "04_dispatch.ipynb",
"FastFunction": "04_dispatch.ipynb",
"FastDispatcher": "04_dispatch.ipynb",
"typedispatch": "04_dispatch.ipynb",
"retain_meta": "04_dispatch.ipynb",
"default_set_meta": "04_dispatch.ipynb",
Expand Down
2 changes: 2 additions & 0 deletions fastcore/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,8 @@ def copy_func(f):
fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
fn.__kwdefaults__ = f.__kwdefaults__
fn.__dict__.update(f.__dict__)
fn.__annotations__.update(f.__annotations__)
fn.__qualname__ = f.__qualname__
return fn

# Cell
Expand Down
202 changes: 72 additions & 130 deletions fastcore/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,154 +4,96 @@
from __future__ import annotations


__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast',
'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types']
__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type',
'retain_types', 'explode_types']

# Cell
#nbdev_comment from __future__ import annotations
from .imports import *
from .foundation import *
from .utils import *
from .meta import delegates

from collections import defaultdict
from plum import Function, Dispatcher

# Cell
def lenient_issubclass(cls, types):
"If possible return whether `cls` is a subclass of `types`, otherwise return False."
if cls is object and types is not object: return False # treat `object` as highest level
try: return isinstance(cls, types) or issubclass(cls, types)
except: return False
def _eval_annotations(f):
"Evaluate future annotations before passing to plum to support backported union operator `|`"
f = copy_func(f)
for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v
return f

# Cell
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
"Return a new list containing all items from the iterable sorted topologically"
l,res = L(list(iterable)),[]
for _ in range(len(l)):
t = l.reduce(lambda x,y: y if cmp(y,x) else x)
res.append(t), l.remove(t)
return res[::-1] if reverse else res
def _pt_repr(o):
"Concise repr of plum types"
n = type(o).__name__
if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
assert len(o.get_types()) == 1
return o.get_types()[0].__name__

# Cell
def _chk_defaults(f, ann):
pass
# Implementation removed until we can figure out how to do this without `inspect` module
# try: # Some callables don't have signatures, so ignore those errors
# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
# if any(p.default!=inspect.Parameter.empty for p in params):
# warn(f"{f.__name__} has default params. These will be ignored.")
# except ValueError: pass

# Cell
def _p2_anno(f):
"Get the 1st 2 annotations of `f`, defaulting to `object`"
hints = type_hints(f)
ann = [o for n,o in hints.items() if n!='return']
if callable(f): _chk_defaults(f, ann)
while len(ann)<2: ann.append(object)
return ann[:2]
class FastFunction(Function):
def __repr__(self):
return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
for s, (f, r) in self.methods.items())

# Cell
class _TypeDict:
def __init__(self): self.d,self.cache = {},{}

def _reset(self):
self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}
self.cache = {}

def add(self, t, f):
"Add type `t` and function `f`"
if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))
for t_ in t: self.d[t_] = f
self._reset()

def all_matches(self, k):
"Find first matching type that is a super-class of `k`"
if k not in self.cache:
types = [f for f in self.d if lenient_issubclass(k,f)]
self.cache[k] = [self.d[o] for o in types]
return self.cache[k]

def __getitem__(self, k):
"Find first matching type that is a super-class of `k`"
res = self.all_matches(k)
return res[0] if len(res) else None

def __repr__(self): return self.d.__repr__()
def first(self): return first(self.d.values())
@delegates(Function.dispatch)
def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)

# Cell
class TypeDispatch:
"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
def __init__(self, funcs=(), bases=()):
self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
for o in L(funcs): self.add(o)
self.inst = None
self.owner = None

def add(self, f):
"Add type `t` and function `f`"
if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
else: a0,a1 = _p2_anno(f)
t = self.funcs.d.get(a0)
if t is None:
t = _TypeDict()
self.funcs.add(a0, t)
t.add(a1, f)

def first(self):
"Get first function in ordered dict of type:func."
return self.funcs.first().first()

def returns(self, x):
"Get the return type of annotation of `x`."
return anno_ret(self[type(x)])

def _attname(self,k): return getattr(k,'__name__',str(k))
def __repr__(self):
r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
for k in self.funcs.d for l,v in self.funcs[k].d.items()]
r = r + [o.__repr__() for o in self.bases]
return '\n'.join(r)

def __call__(self, *args, **kwargs):
ts = L(args).map(type)[:2]
f = self[tuple(ts)]
if not f: return args[0]
if isinstance(f, staticmethod): f = f.__func__
elif self.inst is not None: f = MethodType(f, self.inst)
elif self.owner is not None: f = MethodType(f, self.owner)
return f(*args, **kwargs)

def __get__(self, inst, owner):
self.inst = inst
self.owner = owner
return self

def __getitem__(self, k):
"Find first matching type that is a super-class of `k`"
k = L(k)
while len(k)<2: k.append(object)
r = self.funcs.all_matches(k[0])
for t in r:
o = t[k[1]]
if o is not None: return o
for base in self.bases:
res = base[k]
if res is not None: return res
return None
def __getitem__(self, ts):
"Return the most-specific matching method with fewest parameters"
ts = L(ts)
nargs = min(len(o) for o in self.methods.keys())
while len(ts) < nargs: ts.append(object)
return self.invoke(*ts)

# Cell
class DispatchReg:
"A global registry for `TypeDispatch` objects keyed by function name"
def __init__(self): self.d = defaultdict(TypeDispatch)
def __call__(self, f):
if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
else: nm = f'{f.__qualname__}'
if isinstance(f, classmethod): f=f.__func__
self.d[nm].add(f)
return self.d[nm]

typedispatch = DispatchReg()
class FastDispatcher(Dispatcher):
def _get_function(self, method, owner):
"Adapted from `Dispatcher._get_function` to use `FastFunction`"
name = method.__name__
if owner:
if owner not in self._classes: self._classes[owner] = {}
namespace = self._classes[owner]
else: namespace = self._functions
if name not in namespace: namespace[name] = FastFunction(method, owner=owner)
return namespace[name]

@delegates(Dispatcher.__call__, but='method')
def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)

def _to(self, cls, nm, f, **kwargs):
nf = copy_func(f)
nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner
pf = self(nf, **kwargs)
# plum uses __set_name__ to resolve a plum.Function's owner
# since we assign after class creation, __set_name__ must be called directly
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
pf.__set_name__(cls, nm)
pf = pf.resolve()
setattr(cls, nm, pf)
return pf

def to(self, cls):
"Decorator: dispatch `f` to `cls.f`"
def _inner(f, **kwargs):
nm = f.__name__
# check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on
if nm in cls.__dict__:
pf = getattr(cls, nm)
if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs)
pf.dispatch(f)
else: pf = self._to(cls, nm, f, **kwargs)
return pf
return _inner

typedispatch = FastDispatcher()

# Cell
#nbdev_comment _all_=['cast']
Expand Down
10 changes: 10 additions & 0 deletions fastcore/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum

from copy import copy
from operator import itemgetter,attrgetter
from warnings import warn
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple
Expand All @@ -14,6 +15,15 @@
MethodDescriptorType = type(str.join)
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace

#Patch autoreload (if its loaded) to work with plum
try: from IPython import get_ipython
except ImportError: pass
else:
ip = get_ipython()
if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded:
from plum.autoreload import activate
activate()

NoneType = type(None)
string_classes = (str,bytes)

Expand Down
49 changes: 30 additions & 19 deletions fastcore/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,33 @@
from .utils import *
from .dispatch import *
import inspect
from plum import add_conversion_method

# Cell
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

class _TfmDict(dict):
def __setitem__(self,k,v):
if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)
if k not in self: super().__setitem__(k,TypeDispatch())
self[k].add(v)
def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)

# Cell
class _TfmMeta(type):
def __new__(cls, name, bases, dict):
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
res = super().__new__(cls, name, bases, dict)
for nm in _tfm_methods:
base_td = [getattr(b,nm,None) for b in bases]
if nm in res.__dict__: getattr(res,nm).bases = base_td
else: setattr(res, nm, TypeDispatch(bases=base_td))
res.__signature__ = inspect.signature(res.__init__)
return res

def __call__(cls, *args, **kwargs):
f = args[0] if args else None
n = getattr(f,'__name__',None)
if callable(f) and n in _tfm_methods:
getattr(cls,n).add(f)
return f
return super().__call__(*args, **kwargs)
f = first(args)
n = getattr(f, '__name__', None)
if _is_tfm_method(n, f): return typedispatch.to(cls)(f)
obj = super().__call__(*args, **kwargs)
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
# instances of cls, fix it
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
return obj

@classmethod
def __prepare__(cls, name, bases): return _TfmDict()
Expand All @@ -60,13 +59,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
self.init_enc = enc or dec
if not self.init_enc: return

self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
def identity(x): return x
for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))
if enc:
self.encodes.add(enc)
self.encodes.dispatch(enc)
self.order = getattr(enc,'order',self.order)
if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))
self._name = _get_name(enc)
if dec: self.decodes.add(dec)
if dec: self.decodes.dispatch(dec)

@property
def name(self): return getattr(self, '_name', _get_name(self))
Expand All @@ -85,13 +85,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
def _do_call(self, f, x, **kwargs):
if not _is_tuple(x):
if f is None: return x
ret = f.returns(x) if hasattr(f,'returns') else None
return retain_type(f(x, **kwargs), x, ret)
ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
_, ret = f.resolve_method(*ts)
ret = ret._type
# plum reads empty return annotation as object, retain_type expects it as None
if ret is object: ret = None
return retain_type(f(x,**kwargs), x, ret)
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
return retain_type(res, x)
def encodes(self, x): return x
def decodes(self, x): return x
def setups(self, dl): return dl

add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")

# Cell
#Implement the Transform convention that a None return annotation disables conversion
add_conversion_method(object, NoneType, lambda x: x)

# Cell
class InplaceTransform(Transform):
"A `Transform` that modifies in-place and just returns whatever it's passed"
Expand Down
Loading

0 comments on commit 6df0281

Please sign in to comment.