Associate transforms with random variables
rlouf committed Nov 17, 2022
1 parent c64e12f commit eb55106
Showing 6 changed files with 105 additions and 76 deletions.
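In short: where callers previously cloned a random variable to create a value variable and keyed transforms on that clone, transforms are now keyed on the random variables themselves, and `joint_logprob` hands back the matching value variables. A minimal before/after sketch assembled from the test changes below (the half-normal model and the import paths are illustrative assumptions, not part of the diff):

import aesara.tensor as at
from aeppl import joint_logprob
from aeppl.transforms import LogTransform, TransformValuesRewrite

y_rv = at.random.halfnormal(0, 1, name="y")

# Before this commit, transforms were keyed on manually created value variables:
#   y_vv = y_rv.clone()
#   transform = TransformValuesRewrite({y_vv: LogTransform()})
#   logp, _ = joint_logprob(realized={y_rv: y_vv}, extra_rewrites=transform)

# After this commit, the transform is keyed on the random variable, and the
# returned value variable lives in the transformed (here, log) space.
transform = TransformValuesRewrite({y_rv: LogTransform()})
logp, (y_vv,) = joint_logprob(y_rv, extra_rewrites=transform)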
16 changes: 14 additions & 2 deletions aeppl/joint_logprob.py
@@ -110,8 +110,16 @@ def conditional_logprob(
 
     fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
 
+    # The interface for transformations assumes that the value variables are in
+    # the transformed space. To get the correct `shape` and `dtype` for the
+    # value variables we return, we need to apply the forward transformation to
+    # our RV copies and return the type of the resulting variable as a value
+    # variable.
+    vv_remapper = {}
     if extra_rewrites is not None:
-        extra_rewrites.rewrite(fgraph)
+        extra_rewrites.add_requirements(fgraph, {**original_rv_values, **realized})
+        extra_rewrites.apply(fgraph)
+        vv_remapper = fgraph.values_to_untransformed
 
     rv_remapper = fgraph.preserve_rv_mappings

@@ -145,6 +153,7 @@ def conditional_logprob(
     q = deque(fgraph.toposort())
 
     logprob_vars = {}
+    value_variables = {}
 
     while q:
         node = q.popleft()
@@ -201,6 +210,9 @@ def conditional_logprob(
 
         logprob_vars[q_rv] = q_logprob_var
 
+        q_value_var = vv_remapper.get(q_value_var, q_value_var)
+        value_variables[q_rv] = q_value_var
+
         # Recompute test values for the changes introduced by the
         # replacements above.
         if config.compute_test_value != "off":
@@ -213,7 +225,7 @@ def conditional_logprob(
             f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
         )
 
-    return logprob_vars, list(original_rv_values.values())
+    return logprob_vars, [value_variables[rv] for rv in original_rv_values.keys()]
 
 
 def joint_logprob(
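What the new return contract means for callers, as a hedged sketch (the beta model and the test value are illustrative; per the tests below, `DEFAULT_TRANSFORM` presumably resolves to a log-odds transform for a beta variable): the value variables that `conditional_logprob` and `joint_logprob` return now live in the transformed space, with `shape` and `dtype` taken from the forward-transformed RV copies, so unconstrained inputs are valid:

import aesara.tensor as at
from aeppl import joint_logprob
from aeppl.transforms import DEFAULT_TRANSFORM, TransformValuesRewrite

p_rv = at.random.beta(1, 1, name="p")

transform_rewrite = TransformValuesRewrite({p_rv: DEFAULT_TRANSFORM})
logp, (p_vv,) = joint_logprob(p_rv, extra_rewrites=transform_rewrite)

# `p_vv` is unconstrained (log-odds space), so a value outside (0, 1)
# is a legitimate input here.
print(logp.eval({p_vv: 2.5}))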
55 changes: 43 additions & 12 deletions aeppl/transforms.py
@@ -122,9 +122,14 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     """
 
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+    values_to_untransformed = getattr(fgraph, "values_to_untransformed", None)
     values_to_transforms = getattr(fgraph, "values_to_transforms", None)
 
-    if rv_map_feature is None or values_to_transforms is None:
+    if (
+        rv_map_feature is None
+        or values_to_transforms is None
+        or values_to_untransformed is None
+    ):
         return None  # pragma: no cover
 
     try:
@@ -133,6 +138,7 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     except ValueError:
         return None
 
+    value_var: TensorVariable
     value_var = rv_map_feature.rv_values.get(rv_var, None)
     if value_var is None:
         return None
@@ -154,10 +160,21 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
         trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
 
     # We now assume that the old value variable represents the *transformed space*.
-    # This means that we need to replace all instances of the old value variable
+
+    # Since we initialize value variables as copies of the random variables,
+    # i.e. in the untransformed space, we need to apply the forward
+    # transformation to get value variables in the transformed space.
+    old_value_var: TensorVariable = transform.forward(
+        value_var, *trans_node.inputs
+    ).type()
+    if value_var.name:
+        old_value_var.name = value_var.name
+    values_to_untransformed[value_var] = old_value_var
+
+    # We need to replace all instances of the old value variable
     # with "inversely/un-" transformed versions of itself.
     new_value_var = transformed_variable(
-        transform.backward(value_var, *trans_node.inputs), value_var
+        transform.backward(old_value_var, *trans_node.inputs), old_value_var
     )
     if value_var.name and getattr(transform, "name", None):
         new_value_var.name = f"{value_var.name}_{transform.name}"
@@ -170,16 +187,24 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
 
 
 class TransformValuesMapping(Feature):
-    r"""A `Feature` that maintains a map between value variables and their transforms."""
+    r"""A `Feature` that maintains a map between value variables and their transforms, as
+    well as between value variables and their transformed counterparts.
+    """
 
     def __init__(self, values_to_transforms):
         self.values_to_transforms = values_to_transforms
+        self.values_to_untransformed: Dict[TensorVariable, TensorVariable] = {}
 
     def on_attach(self, fgraph):
         if hasattr(fgraph, "values_to_transforms"):
             raise AlreadyThere()
 
         fgraph.values_to_transforms = self.values_to_transforms
+        fgraph.values_to_untransformed = self.values_to_untransformed
+
+    def update_untransformed_value(self, value, untransformed_value):
+        self.values_to_untransformed[value] = untransformed_value
 
 
 class TransformValuesRewrite(GraphRewriter):
@@ -189,25 +214,31 @@ class TransformValuesRewrite(GraphRewriter):
 
     def __init__(
         self,
-        values_to_transforms: Dict[
+        rvs_to_transforms: Dict[
             TensorVariable, Union[RVTransform, DefaultTransformSentinel, None]
         ],
     ):
         """
         Parameters
         ==========
-        values_to_transforms
-            Mapping between value variables and their transformations. Each
-            value variable can be assigned one of `RVTransform`,
-            ``DEFAULT_TRANSFORM``, or ``None``. If a transform is not specified
-            for a specific value variable it will not be transformed.
+        rvs_to_transforms
+            Mapping between random variables and their transformations. Each
+            random variable can be assigned one of `RVTransform`,
+            ``DEFAULT_TRANSFORM``, or ``None``. Random variables with no
+            transform specified remain unchanged.
         """
 
-        self.values_to_transforms = values_to_transforms
+        self.rvs_to_transforms = rvs_to_transforms
 
-    def add_requirements(self, fgraph):
-        values_transforms_feature = TransformValuesMapping(self.values_to_transforms)
+    def add_requirements(
+        self, fgraph, rv_to_values: Dict[TensorVariable, TensorVariable]
+    ):
+        values_to_transforms = {
+            rv_to_values[rv]: transform
+            for rv, transform in self.rvs_to_transforms.items()
+        }
+        values_transforms_feature = TransformValuesMapping(values_to_transforms)
         fgraph.attach_feature(values_transforms_feature)
 
     def apply(self, fgraph: FunctionGraph):
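To make the mechanics above concrete, a hedged sketch of the round trip the rewrite relies on (the scalar example is illustrative; `forward`, `backward`, and `log_jac_det` are the `RVTransform` methods used in this file):

import aesara.tensor as at
from aeppl.transforms import LogTransform

transform = LogTransform()
x = at.dscalar("x")

y = transform.forward(x)        # untransformed -> transformed: log(x)
x_back = transform.backward(y)  # transformed -> untransformed: exp(log(x))

# Log-Jacobian determinant added to the logprob when `use_jacobian=True`;
# for the log transform it is simply the transformed value itself.
log_jac = transform.log_jac_det(y)

Note the new `add_requirements` signature above: the rewrite now receives the rv-to-value mapping, so it can translate the user-facing, rv-keyed transform dict into the value-keyed dict that the `TransformValuesMapping` feature stores.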
6 changes: 2 additions & 4 deletions tests/test_censoring.py
@@ -154,10 +154,8 @@ def test_clip_transform():
     x_rv = at.random.normal(0.5, 1)
     cens_x_rv = at.clip(x_rv, 0, x_rv)
 
-    cens_x_vv = cens_x_rv.clone()
-
-    transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
-    logp, _ = joint_logprob(realized={cens_x_rv: cens_x_vv}, extra_rewrites=transform)
+    transform = TransformValuesRewrite({cens_x_rv: LogTransform()})
+    logp, (cens_x_vv,) = joint_logprob(cens_x_rv, extra_rewrites=transform)
 
     cens_x_vv_testval = -1
     obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
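A hedged check of what this test now evaluates (the scipy names are an assumption, as is `joint_logprob` defaulting to `use_jacobian=True`): at `cens_x_vv = -1` the value in the original space is `exp(-1) ≈ 0.37`, above the censoring point at 0, so the expected logprob is the normal log-density plus the log transform's Jacobian term:

import numpy as np
import scipy.stats as st

cens_x_vv_testval = -1
x_val = np.exp(cens_x_vv_testval)  # back in the original, untransformed space
# Normal(0.5, 1) log-density at x, plus the log-Jacobian term, which for
# y = log(x) is y itself.
expected = st.norm(0.5, 1).logpdf(x_val) + cens_x_vv_testval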
2 changes: 1 addition & 1 deletion tests/test_joint_logprob.py
@@ -247,6 +247,6 @@ def test_multiple_rvs_to_same_value_raises():
     x = x_rv1.type()
     x.name = "x"
 
-    msg = "More than one logprob factor was assigned to the value variable x"
+    msg = "More than one logprob factor was assigned to the random variable x"
     with pytest.raises(ValueError, match=msg):
         joint_logprob(realized={x_rv1: x, x_rv2: x})
5 changes: 4 additions & 1 deletion tests/test_mixture.py
@@ -41,7 +41,10 @@ def create_mix_model(size, axis):
     I_rv = env["I_rv"]
     M_rv = env["M_rv"]
 
-    with pytest.raises(RuntimeError, match="could not be derived: {m}"):
+    with pytest.raises(
+        RuntimeError,
+        match="The logprob terms of the following random variables could not be derived: {M}",
+    ):
         conditional_logprob(M_rv, I_rv, X_rv)
 
     with pytest.raises(NotImplementedError):
97 changes: 41 additions & 56 deletions tests/test_transforms.py
@@ -4,7 +4,6 @@
 import pytest
 import scipy as sp
 import scipy.special
-from aesara.graph.basic import equal_computations
 from aesara.graph.fg import FunctionGraph
 from numdifftools import Jacobian

@@ -22,7 +21,6 @@
     TransformValuesMapping,
     TransformValuesRewrite,
     _default_transformed_rv,
-    transformed_variable,
 )
 from tests.utils import assert_no_rvs

@@ -176,15 +174,13 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
 
     a = at_dist(*dist_params, size=size)
     a.name = "a"
-    a_value_var = at.tensor(a.dtype, shape=(None,) * a.ndim)
-    a_value_var.name = "a_value"
 
     b = at.random.normal(a, 1.0)
     b.name = "b"
 
-    transform_rewrite = TransformValuesRewrite({a_value_var: DEFAULT_TRANSFORM})
-    res, (b_value_var,) = joint_logprob(
-        b, realized={a: a_value_var}, extra_rewrites=transform_rewrite
+    transform_rewrite = TransformValuesRewrite({a: DEFAULT_TRANSFORM})
+    res, (b_value_var, a_value_var) = joint_logprob(
+        b, a, extra_rewrites=transform_rewrite
     )
 
     test_val_rng = np.random.RandomState(3238)
@@ -268,12 +264,10 @@ def a_backward_fn_(x):
 
 @pytest.mark.parametrize("use_jacobian", [True, False])
 def test_simple_transformed_logprob_nojac(use_jacobian):
     X_rv = at.random.halfnormal(0, 3, name="X")
-    x_vv = X_rv.clone()
-    x_vv.name = "x"
 
-    transform_rewrite = TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
-    tr_logp, _ = joint_logprob(
-        realized={X_rv: x_vv},
+    transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
+    tr_logp, (x_vv,) = joint_logprob(
+        X_rv,
         extra_rewrites=transform_rewrite,
         use_jacobian=use_jacobian,
     )
@@ -321,19 +315,17 @@ def test_hierarchical_uniform_transform():
     upper_rv = at.random.uniform(9, 10, name="upper")
     x_rv = at.random.uniform(lower_rv, upper_rv, name="x")
 
-    lower = lower_rv.clone()
-    upper = upper_rv.clone()
-    x = x_rv.clone()
-
     transform_rewrite = TransformValuesRewrite(
         {
-            lower: DEFAULT_TRANSFORM,
-            upper: DEFAULT_TRANSFORM,
-            x: DEFAULT_TRANSFORM,
+            lower_rv: DEFAULT_TRANSFORM,
+            upper_rv: DEFAULT_TRANSFORM,
+            x_rv: DEFAULT_TRANSFORM,
         }
     )
-    logp, _ = joint_logprob(
-        realized={lower_rv: lower, upper_rv: upper, x_rv: x},
+    logp, (lower, upper, x) = joint_logprob(
+        lower_rv,
+        upper_rv,
+        x_rv,
         extra_rewrites=transform_rewrite,
     )

@@ -346,20 +338,18 @@ def test_nondefault_transforms():
     scale_rv = at.random.uniform(-1, 1, name="scale")
     x_rv = at.random.normal(loc_rv, scale_rv, name="x")
 
-    loc = loc_rv.clone()
-    scale = scale_rv.clone()
-    x = x_rv.clone()
-
     transform_rewrite = TransformValuesRewrite(
         {
-            loc: None,
-            scale: LogOddsTransform(),
-            x: LogTransform(),
+            loc_rv: None,
+            scale_rv: LogOddsTransform(),
+            x_rv: LogTransform(),
         }
     )
 
-    logp, _ = joint_logprob(
-        realized={loc_rv: loc, scale_rv: scale, x_rv: x},
+    logp, (loc, scale, x) = joint_logprob(
+        loc_rv,
+        scale_rv,
+        x_rv,
         extra_rewrites=transform_rewrite,
     )

@@ -391,12 +381,11 @@ def test_default_transform_multiout():
     # multiple outputs and no default output.
     sd = at.linalg.svd(at.eye(1))[1][0]
     x_rv = at.random.normal(0, sd, name="x")
-    x = x_rv.clone()
 
-    transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
+    transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})
 
-    logp, _ = joint_logprob(
-        realized={x_rv: x},
+    logp, (x,) = joint_logprob(
+        x_rv,
         extra_rewrites=transform_rewrite,
     )

@@ -412,12 +401,11 @@ def test_nonexistent_default_transform():
     transform does not fail
     """
     x_rv = at.random.normal(name="x")
-    x = x_rv.clone()
 
-    transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
+    transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})
 
-    logp, _ = joint_logprob(
-        realized={x_rv: x},
+    logp, (x,) = joint_logprob(
+        x_rv,
         extra_rewrites=transform_rewrite,
     )

@@ -446,9 +434,8 @@ def test_original_values_output_dict():
     the logprob factor
     """
     p_rv = at.random.beta(1, 1, name="p")
-    p_vv = p_rv.clone()
 
-    tr = TransformValuesRewrite({p_vv: DEFAULT_TRANSFORM})
+    tr = TransformValuesRewrite({p_rv: DEFAULT_TRANSFORM})
    logp_dict, _ = conditional_logprob(p_rv, extra_rewrites=tr)
 
     assert p_rv in logp_dict
@@ -469,29 +456,28 @@ def test_mixture_transform():
     Y_rv = at.stack([Y_1_rv, Y_2_rv])[I_rv]
     Y_rv.name = "Y"
 
-    logp_no_trans, (y_vv, i_vv) = joint_logprob(Y_rv, I_rv)
+    logp, (y_vv, i_vv) = joint_logprob(
+        Y_rv,
+        I_rv,
+    )
 
-    transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()})
+    transform_rewrite = TransformValuesRewrite({Y_rv: LogOddsTransform()})
 
     with pytest.warns(None) as record:
         # This shouldn't raise any warnings
-        logp_trans, _ = joint_logprob(
-            realized={Y_rv: y_vv, I_rv: i_vv},
+        logp_trans, (y_vv_trans, i_vv_trans) = joint_logprob(
+            Y_rv,
+            I_rv,
             extra_rewrites=transform_rewrite,
             use_jacobian=False,
         )
 
     assert not record.list
 
-    # The untransformed graph should be the same as the transformed graph after
-    # replacing the `Y_rv` value variable with a transformed version of itself
-    logp_nt_fg = FunctionGraph(outputs=[logp_no_trans], clone=False)
-    y_trans = transformed_variable(at.exp(y_vv), y_vv)
-    y_trans.name = "y_log"
-    logp_nt_fg.replace(y_vv, y_trans)
-    logp_nt = logp_nt_fg.outputs[0]
-
-    assert equal_computations([logp_nt], [logp_trans])
+    logp_fn = aesara.function((i_vv, y_vv), logp)
+    logp_trans_fn = aesara.function((i_vv_trans, y_vv_trans), logp_trans)
+    assert np.isclose(logp_trans_fn(0, np.log(0.1 / 0.9)), logp_fn(0, 0.1))
+    assert np.isclose(logp_trans_fn(1, np.log(0.1 / 0.9)), logp_fn(1, 0.1))
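Why the last two checks should hold, as a sketch of the underlying identity (the scipy calls are assumptions): with `use_jacobian=False` the transformed logprob omits the Jacobian term, so evaluating it at `y = logit(p)` reproduces the untransformed logprob at `p` exactly:

import numpy as np
import scipy.special

p = 0.1
y = scipy.special.logit(p)  # log(0.1 / 0.9), the log-odds value used above
# With the Jacobian omitted, logp_trans(y) = logp(sigmoid(y)), and
# sigmoid(logit(p)) == p, so both evaluations coincide.
assert np.isclose(scipy.special.expit(y), p)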


def test_invalid_interval_transform():
Expand Down Expand Up @@ -642,11 +628,10 @@ def test_scale_transform_rv(rv_size, scale_type):
def test_transformed_rv_and_value():
y_rv = at.random.halfnormal(-1, 1, name="base_rv") + 1
y_rv.name = "y"
y_vv = y_rv.clone()

transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()})
transform_rewrite = TransformValuesRewrite({y_rv: LogTransform()})

logp, _ = joint_logprob(realized={y_rv: y_vv}, extra_rewrites=transform_rewrite)
logp, (y_vv,) = joint_logprob(y_rv, extra_rewrites=transform_rewrite)
assert_no_rvs(logp)
logp_fn = aesara.function([y_vv], logp)

Expand Down
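For reference, a hedged numerical sketch of what `logp_fn` should compute here (the scipy names are assumptions): `y_rv` is a shifted half-normal with support on `[0, inf)`, and with the presumed default `use_jacobian=True` the log transform contributes its Jacobian term, which is `y_vv` itself:

import numpy as np
import scipy.stats as st

y_vv_val = 0.5
y_val = np.exp(y_vv_val)  # invert the log transform
# `y_rv` is HalfNormal(loc=-1, scale=1) shifted by +1, plus the log-Jacobian term.
expected = st.halfnorm(-1, 1).logpdf(y_val - 1) + y_vv_val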
