Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logcdf for CensoredRV #6894

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
from pymc.util import check_dist_not_registered
from pymc.logprob.abstract import _logcdf


class CensoredRV(SymbolicRandomVariable):
Expand Down Expand Up @@ -148,3 +149,35 @@ def moment_censored(op, rv, dist, lower, upper):
)
moment = pt.full_like(dist, moment)
return moment

@_logcdf.register(CensoredRV)
def censored_logcdf(op, value, *inputs, **kwargs):
*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]

base_rv_op = op.base_rv_op
logcdf_cens = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)

lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
aadya940 marked this conversation as resolved.
Show resolved Hide resolved

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

#If function is left censored, set logcdf to -np,inf
if is_lower_bounded:
logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_cens)
aadya940 marked this conversation as resolved.
Show resolved Hide resolved

#If function is right censored, set logcdf to 0
if is_upper_bounded:
logcdf_trunc = pt.switch(value <= upper, logcdf_cens, 0.0)

#If in domain, set logcdf as if uncensored
if is_lower_bounded and is_upper_bounded:
logcdf_trunc = check_parameters(
logcdf_cens,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logcdf_cens
43 changes: 43 additions & 0 deletions tests/distributions/test_censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pymc as pm

from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.basic import logcdf


class TestCensored:
Expand Down Expand Up @@ -110,3 +111,45 @@ def test_dist_broadcasted_by_lower_upper(self):
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
)
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)

@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_censoring_continuous_logcdf(op_type, lower, upper):
loc = 0.15
scale = 10
op = icdf_normal if op_type == "icdf" else rejection_normal

x = op(loc, scale, name="x")
xt = pm.Censored.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, pm.CensoredRV)

xt_vv = xt.clone()
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))

for bound in (lower, upper):
if np.isinf(bound):
return
for offset in (-1, 0, 1):
test_xt_v = bound + offset
assert xt_logcdf_fn(test_xt_v) is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to test the numerical outputs are correct (same for the other test)



@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_censoring_discrete_logcdf(op_type, lower, upper):
p = 0.7
op = icdf_geometric if op_type == "icdf" else rejection_geometric

x = op(p, name="x")
xt = pm.Censored.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, pm.CensoredRV)

xt_vv = xt.clone()
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))

for bound in (lower, upper):
if np.isinf(bound):
continue
for offset in (-1, 0, 1):
test_xt_v = bound + offset
assert xt_logcdf_fn(test_xt_v) is not None