Skip to content

Commit

Permalink
Fix ValueVariable checks
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 11, 2023
1 parent cbd0e2f commit 906c10d
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 16 deletions.
4 changes: 2 additions & 2 deletions aeppl/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def find_measurable_clips(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and not isinstance(base_var, ValuedVariable)
and not isinstance(base_var.owner.op, ValuedVariable)
):
return None

Expand Down Expand Up @@ -199,7 +199,7 @@ def construct_measurable_rounding(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and not isinstance(base_var, ValuedVariable)
and not isinstance(base_var.owner.op, ValuedVariable)
# Rounding only makes sense for continuous variables
and base_var.dtype.startswith("float")
):
Expand Down
2 changes: 1 addition & 1 deletion aeppl/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and not isinstance(base_rv, ValuedVariable)
and not isinstance(base_rv.owner.op, ValuedVariable)
):
return None # pragma: no cover

Expand Down
6 changes: 4 additions & 2 deletions aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def mixture_replace(fgraph, node):

mixture_res, join_axis = get_stack_mixture_vars(node)

if mixture_res is None or any(isinstance(rv, ValuedVariable) for rv in mixture_res):
if mixture_res is None or any(
rv.owner and isinstance(rv.owner.op, ValuedVariable) for rv in mixture_res
):
return None # pragma: no cover

mixing_indices = node.inputs[1:]
Expand Down Expand Up @@ -314,7 +316,7 @@ def switch_mixture_replace(fgraph, node):
if not (
component_rv.owner
and isinstance(component_rv.owner.op, MeasurableVariable)
and not isinstance(component_rv, ValuedVariable)
and not isinstance(component_rv.owner.op, ValuedVariable)
):
return None
new_node = assign_custom_measurable_outputs(component_rv.owner)
Expand Down
2 changes: 1 addition & 1 deletion aeppl/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def incsubtensor_rv_replace(fgraph, node):
if not (
base_rv_var.owner
and isinstance(base_rv_var.owner.op, MeasurableVariable)
and not isinstance(base_rv_var, ValuedVariable)
and not isinstance(base_rv_var.owner.op, ValuedVariable)
):
return None # pragma: no cover

Expand Down
2 changes: 1 addition & 1 deletion aeppl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def update_scan_value_vars(
"""

# if not any(isinstance(out, ValuedVariable) for out in node.outputs):
# if not any(isinstance(out.owner.op, ValuedVariable) for out in node.outputs):
# return new_node.outputs

# Get any `Subtensor` outputs that have been applied to outputs of this
Expand Down
4 changes: 2 additions & 2 deletions aeppl/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def find_measurable_stacks(
if not all(
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and not isinstance(base_var, ValuedVariable)
and not isinstance(base_var.owner.op, ValuedVariable)
for base_var in base_vars
):
return None # pragma: no cover
Expand Down Expand Up @@ -178,7 +178,7 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
if not (
base_var.owner
and isinstance(base_var.owner.op, RandomVariable)
and not isinstance(base_var, ValuedVariable)
and not isinstance(base_var.owner.op, ValuedVariable)
):
return None # pragma: no cover

Expand Down
14 changes: 7 additions & 7 deletions aeppl/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def construct_elemwise_transform(
for idx, inp in enumerate(node.inputs)
if inp.owner
and isinstance(inp.owner.op, MeasurableVariable)
and not isinstance(inp, ValuedVariable)
and not isinstance(inp.owner.op, ValuedVariable)
]

if len(measurable_inputs) != 1:
Expand All @@ -562,19 +562,19 @@ def expand(var: TensorVariable) -> List[TensorVariable]:
if (
var.owner
and not isinstance(var.owner.op, MeasurableVariable)
and not isinstance(var, ValuedVariable)
and not isinstance(var.owner.op, ValuedVariable)
):
new_vars.extend(reversed(var.owner.inputs))

return new_vars

if any(
ancestor_node
for ancestor_node in walk(other_inputs, expand, False)
var
for var in walk(other_inputs, expand, False)
if (
ancestor_node.owner
and isinstance(ancestor_node.owner.op, MeasurableVariable)
and not isinstance(ancestor_node, ValuedVariable)
var.owner
and isinstance(var.owner.op, MeasurableVariable)
and not isinstance(var.owner.op, ValuedVariable)
)
):
return None
Expand Down
15 changes: 15 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def test_transform_measurable_sub():
assert np.isclose(z_logp_fn(7.3), sp.stats.norm.logpdf(7.3, loc=-4.0))


@pytest.mark.xfail(reason="This needs to be reconsidered")
def test_transform_reused_measurable():
srng = at.random.RandomStream(0)

Expand All @@ -804,3 +805,17 @@ def test_transform_reused_measurable():
exp_res = sp.stats.lognorm(s=1).logpdf(z_val) + sp.stats.norm().logpdf(z_val)

np.testing.assert_allclose(logp_fn(z_val), exp_res)


def test_transform_sub_valued():
"""Test the case when one of the transformed inputs is a `ValuedVariable`."""
srng = at.random.RandomStream(0)

A_rv = srng.normal(1.0, name="A")
X_rv = srng.normal(1.0, name="X")
Z_rv = A_rv - X_rv

logp, (z_vv, a_vv) = joint_logprob(Z_rv, A_rv)
z_logp_fn = aesara.function([z_vv, a_vv], logp)
exp_logp = sp.stats.norm.logpdf(5.0 - 7.3, 1.0) + sp.stats.norm.logpdf(5.0, 1.0)
assert np.isclose(z_logp_fn(7.3, 5.0), exp_logp)

0 comments on commit 906c10d

Please sign in to comment.