Skip to content

Commit

Permalink
Added icdf - logcdf consistency tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
gokuld committed Apr 29, 2023
1 parent a50eeab commit a158f7e
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
77 changes: 77 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,83 @@ def check_selfconsistency_discrete_logcdf(
)


def check_selfconsistency_continuous_icdf(
distribution: Distribution,
paramdomains: Dict[str, Domain],
decimal: Optional[int] = None,
n_samples: int = 100,
) -> None:
"""
Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values.
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

dist = create_dist_from_paramdomains(distribution, paramdomains)
value = dist.type()
value.name = "value"

dist_icdf = icdf(dist, value)
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)

dist_logcdf = logcdf(dist, value)
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)

domains = paramdomains.copy()
domains["value"] = Domain(np.linspace(0.1, 1, 100))

for point in product(domains, n_samples=n_samples):
point = dict(point)
value = point.pop("value")

with pytensor.config.change_flags(mode=Mode("py")):
npt.assert_almost_equal(
value,
np.exp(dist_logcdf_fn(**point, value=dist_icdf_fn(**point, value=value))),
decimal=decimal,
err_msg=f"point: {point}, value: {value}",
)


def check_selfconsistency_discrete_icdf(
distribution: Distribution,
domain: Domain,
paramdomains: Dict[str, Domain],
n_samples: int = 100,
) -> None:
"""
Check that the icdf and logcdf functions of the distribution are
consistent for a sample of values in the domain of the
distribution.
"""
decimal = select_by_precision(float64=6, float32=3)
dist = create_dist_from_paramdomains(distribution, paramdomains)

value = pt.TensorType(dtype="float64", shape=[])("value")

dist_icdf = icdf(dist, value)
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)

dist_logcdf = logcdf(dist, value)
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)

domains = paramdomains.copy()
domains["value"] = domain

for point in product(domains, n_samples=n_samples):
point = dict(point)
value = point.pop("value")

with pytensor.config.change_flags(mode=Mode("py")):
expected_value = value
computed_value = dist_icdf_fn(
**point, value=np.exp(dist_logcdf_fn(**point, value=value))
)
assert (
expected_value == computed_value
), f"expected_value = {expected_value}, computed_value = {computed_value}, {point}"


def assert_moment_is_expected(model, expected, check_finite_logp=True):
fn = make_initial_point_fn(
model=model,
Expand Down
5 changes: 5 additions & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
check_icdf,
check_logcdf,
check_logp,
check_selfconsistency_continuous_icdf,
continuous_random_tester,
seeded_numpy_distribution_builder,
seeded_scipy_distribution_builder,
Expand Down Expand Up @@ -424,6 +425,10 @@ def scipy_log_cdf(value, a, b):
{"a": Rplus, "b": Rplus},
scipy_log_cdf,
)
check_selfconsistency_continuous_icdf(
pm.Kumaraswamy,
{"a": Rplusbig, "b": Rplusbig},
)

def test_exponential(self):
check_logp(
Expand Down
11 changes: 11 additions & 0 deletions tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
check_icdf,
check_logcdf,
check_logp,
check_selfconsistency_discrete_icdf,
check_selfconsistency_discrete_logcdf,
seeded_numpy_distribution_builder,
seeded_scipy_distribution_builder,
Expand Down Expand Up @@ -123,6 +124,11 @@ def test_discrete_unif(self):
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
skip_paramdomain_outside_edge_test=True,
)
check_selfconsistency_discrete_icdf(
pm.DiscreteUniform,
Rdunif,
{"lower": -Rplusdunif, "upper": Rplusdunif},
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
with pytensor.config.change_flags(mode=Mode("py")):
Expand Down Expand Up @@ -156,6 +162,11 @@ def test_geometric(self):
{"p": Unit},
st.geom.ppf,
)
check_selfconsistency_discrete_icdf(
pm.Geometric,
Nat,
{"p": Unit},
)

def test_hypergeometric(self):
def modified_scipy_hypergeom_logcdf(value, N, k, n):
Expand Down

0 comments on commit a158f7e

Please sign in to comment.