diff --git a/pymc/testing.py b/pymc/testing.py index 4122f0f574..a61ae3dbbe 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -699,6 +699,7 @@ def check_selfconsistency_discrete_icdf( distribution: Distribution, domain: Domain, paramdomains: Dict[str, Domain], + decimal: Optional[int] = None, n_samples: int = 100, ) -> None: """ @@ -706,6 +707,13 @@ def check_selfconsistency_discrete_icdf( consistent for a sample of values in the domain of the distribution. """ + + def ftrunc(values, decimal=0): + return np.trunc(values * 10**decimal) / (10**decimal) + + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + dist = create_dist_from_paramdomains(distribution, paramdomains) value = pt.TensorType(dtype="float64", shape=[])("value") @@ -726,7 +734,7 @@ def check_selfconsistency_discrete_icdf( with pytensor.config.change_flags(mode=Mode("py")): expected_value = value computed_value = dist_icdf_fn( - **point, value=np.exp(dist_logcdf_fn(**point, value=value)) + **point, value=ftrunc(np.exp(dist_logcdf_fn(**point, value=value)), decimal=decimal) ) npt.assert_almost_equal( expected_value,