Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BindReturn type hint for make_funsor #518

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 75 additions & 15 deletions funsor/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
import makefun

from funsor.instrument import debug_logged
from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor
from funsor.terms import (
Funsor,
FunsorMeta,
Subs,
Variable,
eager,
substitute,
to_funsor,
)
from funsor.util import as_callable


Expand Down Expand Up @@ -137,7 +145,9 @@ def _get_dependent_args(fields, hints, args):
return {
name: arg if isinstance(hint, Value) else arg.output
for name, arg, hint in zip(fields, args, hints)
if hint in (Funsor, Bound) or isinstance(hint, (Has, Value))
if hint in (Funsor, Bound)
or isinstance(hint, (Has, Value))
or (isinstance(hint, Fresh) and name in hint.args)
}


Expand Down Expand Up @@ -179,19 +189,41 @@ def Unflatten(
for name, hint in input_types.items():
if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))):
raise TypeError(f"Invalid type hint {name}: {hint}")
if any(
isinstance(hint, Fresh) and arg in hint.args
for arg, hint in input_types.items()
):
input_types["bind_return"] = Value[frozenset]

def new_fn(*args):
args, bind_return = args[:-1], args[-1]
result = fn(*args)
return Subs(result, bind_return)

else:
new_fn = fn

output_type = input_types.pop("return")
hints = tuple(input_types.values())

class ResultMeta(FunsorMeta):
def __call__(cls, *args):
def __call__(cls, *args, bind_return=None):
args = list(args)

# Bind-and-return variables
if bind_return is None:
bind_return = frozenset(
(arg, arg)
for hint, arg, arg_name in zip(hints, args, cls._ast_fields)
if isinstance(hint, Fresh) and arg_name in hint.args
)

# Compute domains of bound variables.
for i, (name, arg) in enumerate(zip(cls._ast_fields, args)):
hint = input_types[name]
if hint is Funsor or isinstance(hint, Has): # TODO support domains
args[i] = to_funsor(arg)
elif hint is Bound:
elif hint is Bound or (isinstance(hint, Fresh) and name in hint.args):
for other in args:
if isinstance(other, Funsor):
domain = other.inputs.get(arg, None)
Expand All @@ -209,21 +241,32 @@ def __call__(cls, *args):

# Compute domains of fresh variables.
dependent_args = _get_dependent_args(cls._ast_fields, hints, args)
for i, (hint, arg) in enumerate(zip(hints, args)):
if isinstance(hint, Fresh):
for i, (hint, arg, arg_name) in enumerate(
zip(hints, args, cls._ast_fields)
):
if isinstance(hint, Fresh) and arg_name in hint.args:
domain = hint(**dependent_args)
args[i] = to_funsor(arg.name, domain)
elif isinstance(hint, Fresh):
domain = hint(**dependent_args)
args[i] = to_funsor(arg, domain)

# Append bind_return to args
if bind_return:
args.append(bind_return)
return super().__call__(*args)

@makefun.with_signature(
"__init__({})".format(", ".join(["self"] + list(input_types)))
)
def __init__(self, **kwargs):
args = tuple(kwargs[k] for k in self._ast_fields)
bind_return = dict(kwargs.get("bind_return", dict()))
dependent_args = _get_dependent_args(self._ast_fields, hints, args)
output = output_type(**dependent_args)
inputs = OrderedDict()
bound = {}
fresh = frozenset()
for hint, arg, arg_name in zip(hints, args, self._ast_fields):
if hint is Funsor:
assert isinstance(arg, Funsor)
Expand All @@ -232,28 +275,45 @@ def __init__(self, **kwargs):
assert isinstance(arg, Funsor)
inputs.update(arg.inputs)
for name in hint.bound:
if kwargs[name] not in arg.input_vars:
if kwargs[name].name not in arg.inputs:
warnings.warn(
f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}."
f"Are you sure {name} will always appear in {arg_name}?",
SyntaxWarning,
)
for hint, arg in zip(hints, args):
for hint, arg, arg_name in zip(hints, args, self._ast_fields):
if hint is Bound:
bound[arg.name] = inputs.pop(arg.name)
elif isinstance(hint, Fresh) and arg_name in hint.args:
bound[arg.name] = inputs.pop(arg.name)
inputs[bind_return[arg.name]] = arg.output
fresh |= frozenset({bind_return[arg.name]})
for hint, arg in zip(hints, args):
if isinstance(hint, Fresh):
for k, d in arg.inputs.items():
if k not in bound:
inputs[k] = d
fresh = frozenset()
if arg.name not in bound:
inputs[arg.name] = arg.output
fresh |= frozenset({arg.name})
Funsor.__init__(self, inputs, output, fresh, bound)
for name, arg in zip(self._ast_fields, args):
if name == "bind_return":
arg = dict(arg)
setattr(self, name, arg)

def _alpha_convert(self, alpha_subs):
alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
return Funsor._alpha_convert(self, alpha_subs)
result = []
new_alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
for hint, value, arg_name in zip(hints, self._ast_values, self._ast_fields):
if isinstance(hint, Fresh) and arg_name in hint.args:
result.append(to_funsor(alpha_subs[value.name], value.output))
elif arg_name == "bind_return":
result.append(
frozenset(
(alpha_subs.get(k, k), v) for k, v in self.bind_return.items()
)
)
else:
result.append(substitute(value, new_alpha_subs))
return tuple(result)

name = _get_name(fn)
ResultMeta.__name__ = f"{name}Meta"
Expand All @@ -263,7 +323,7 @@ def _alpha_convert(self, alpha_subs):
pattern = (Result,) + tuple(
_hint_to_pattern(input_types[k]) for k in Result._ast_fields
)
eager.register(*pattern)(_erase_types(fn))
eager.register(*pattern)(_erase_types(new_fn))
return Result


Expand Down
41 changes: 41 additions & 0 deletions test/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,44 @@ def MatMul(
# To preserve extensionality, should only error on reflect
xy = MatMul(x, y, "b")
check_funsor(xy, {"a": Bint[3], "c": Bint[4], "d": Bint[3]}, Real)


def test_unroll():
@make_funsor
def Unroll(
x: Has[{"ax"}], # noqa: F821
ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
k: Value[int],
kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
return x(**{ax.name: ax + kernel})

x = random_tensor(OrderedDict(a=Bint[5]))
with reflect:
y = Unroll(x, "a", 2, "kernel")
assert y.fresh == frozenset({"a", "kernel"})
assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound)
check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)
z = reinterpret(y)
assert isinstance(z, Tensor)
check_funsor(z, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)


def test_softmax():
@make_funsor
def Softmax(
x: Has[{"ax"}], # noqa: F821
ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
y = x - x.reduce(ops.logaddexp, ax)
return y.exp()

x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4]))
with reflect:
y = Softmax(x, "a")
assert y.fresh == frozenset({"a"})
assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound)
check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real)
z = reinterpret(y)
assert isinstance(z, Tensor)
check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real)