Skip to content

Commit

Permalink
Add function to support negation
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Nov 26, 2023
1 parent ace4644 commit 69f1b32
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
14 changes: 3 additions & 11 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_negation
from pymc.logprob.utils import find_negated_var
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold

Expand Down Expand Up @@ -189,18 +189,10 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
return None

if len(base_var.owner.inputs) == 2:
if base_var.owner.inputs[0] is None:
base_rv = base_var.owner.inputs[1]
scalar_constant = base_var.owner.inputs[0]
else:
base_rv = base_var.owner.inputs[0]
scalar_constant = base_var.owner.inputs[1]
else:
return None
base_rv = find_negated_var(base_var)

# negation is rv * (-1). Hence the scalar_op must be Mul
if check_negation(base_var.owner.op.scalar_op, scalar_constant) is False:
if base_rv is None:
return None

# Non-univariate distributions and non-RVs must be rejected
Expand Down
8 changes: 6 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@
_logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability, check_negation
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
find_negated_var,
)


class Transform(abc.ABC):
Expand Down Expand Up @@ -442,7 +446,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
if not (
check_negation(node.op.scalar_op, node.inputs[0]) or isinstance(node.op.scalar_op, Add)
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
):
return None
if not measurable_output.type.dtype.startswith("int"):
Expand Down
23 changes: 18 additions & 5 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,29 @@ def expand(r):
}


def check_negation(scalar_op, scalar_constant):
"""Make sure that the base variable invovles a multiplication with -1"""
def find_negated_var(base_var):
"""Make sure that the base variable involves a multiplication with -1"""

scalar_op = base_var.owner.op.scalar_op

if len(base_var.owner.inputs) == 2:
if base_var.owner.inputs[0] is None:
base_rv = base_var.owner.inputs[0]
scalar_constant = base_var.owner.inputs[1]

Check warning on line 319 in pymc/logprob/utils.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/utils.py#L318-L319

Added lines #L318 - L319 were not covered by tests
else:
base_rv = base_var.owner.inputs[1]
scalar_constant = base_var.owner.inputs[0]
else:
return None

try:
if not (
isinstance(scalar_op, Mul)
and get_underlying_scalar_constant_value(scalar_constant) == -1
):
return False
return None

except NotScalarConstantError:
return False
return None

return True
return base_rv
2 changes: 1 addition & 1 deletion tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_min_max_bernoulli():
p = 0.7
q = 1 - p
n = 3
x = pm.Bernoulli.dist(p=p, shape=(n,))
x = pm.Bernoulli.dist(name="x", p=p, shape=(n,))
value = pt.scalar("value", dtype=int)

max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
Expand Down
14 changes: 7 additions & 7 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
from pymc.logprob.transforms import (
ArccoshTransform,
Expand Down Expand Up @@ -688,9 +689,8 @@ def test_discrete_rv_unary_transform_fails():
conditional_logp({y_rv: y_rv.clone()})


# add 2 tests. One fir supported and one for unsupported
def test_discrete_rv_multinary_transform():
y_rv = 5 + pt.random.poisson(1)
y_rv = 5 * pt.random.poisson(1)
with pytest.raises(RuntimeError, match="could not be derived"):
conditional_logp({y_rv: y_rv.clone()})

Expand All @@ -709,10 +709,10 @@ def test_invalid_broadcasted_transform_rv_fails():

def test_discrete_measurable_cdf_icdf():
p = 0.7
rv = -pm.Bernoulli.dist(p=p)
rv = -Bernoulli.dist(p=p)

# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
assert pm.logp(rv, -2).eval() == -np.inf # Correct
assert pm.logp(rv, -1).eval() == np.log(p) # Correct
assert pm.logp(rv, 0).eval() == np.log(1 - p) # Correct
assert pm.logp(rv, 1).eval() == -np.inf # Correct
assert logp(rv, -2).eval() == -np.inf # Correct
assert logp(rv, -1).eval() == np.log(p) # Correct
assert logp(rv, 0).eval() == np.log(1 - p) # Correct
assert logp(rv, 1).eval() == -np.inf # Correct

0 comments on commit 69f1b32

Please sign in to comment.