From 640480545da026827a3b610a49e089958c7dd59a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 30 Mar 2023 15:54:09 +0200
Subject: [PATCH] Add model debug helper

---
 pymc/model.py       | 154 +++++++++++++++++++++++++++++++++++++++++++-
 tests/test_model.py |  72 +++++++++++++++++++++
 2 files changed, 225 insertions(+), 1 deletion(-)

diff --git a/pymc/model.py b/pymc/model.py
index 9a3dd28aa5..6b7b5d142f 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import functools
+import sys
 import threading
 import types
 import warnings
@@ -24,6 +25,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Tuple,
@@ -39,6 +41,7 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps
 
+from pytensor.compile import DeepCopyOp, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -46,6 +49,7 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.sharedvar import ScalarSharedVariable
 from pytensor.tensor.var import TensorConstant, TensorVariable
 
@@ -61,6 +65,7 @@
 )
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob.basic import joint_logp
+from pymc.logprob.utils import ParameterValueError
 from pymc.pytensorf import (
     PointFunc,
     SeedSequenceSeed,
@@ -1779,7 +1784,8 @@ def check_start_vals(self, start):
             raise SamplingError(
                 "Initial evaluation of model at starting point failed!\n"
                 f"Starting values:\n{elem}\n\n"
-                f"Logp initial evaluation results:\n{initial_eval}"
+                f"Logp initial evaluation results:\n{initial_eval}\n"
+                "You can call `model.debug()` for more details."
             )
 
     def point_logps(self, point=None, round_vals=2):
@@ -1811,6 +1817,152 @@ def point_logps(self, point=None, round_vals=2):
             )
         }
 
+    def debug(
+        self,
+        point: Optional[Dict[str, np.ndarray]] = None,
+        fn: Literal["logp", "dlogp", "random"] = "logp",
+        verbose: bool = False,
+    ):
+        """Debug model function at point.
+
+        The method evaluates `fn` for one variable at a time.
+        When an evaluation fails or produces a non-finite value we print:
+        1. The graph of the parameters
+        2. The value of the parameters (if those can be evaluated)
+        3. The output of `fn` (if it can be evaluated)
+
+        This function should help to quickly narrow down invalid parametrizations.
+
+        Parameters
+        ----------
+        point : Point
+            Point at which model function should be evaluated
+        fn : str, default "logp"
+            Function to be used for debugging. Can be one of [logp, dlogp, random].
+        verbose : bool, default False
+            Whether to show a more verbose PyTensor output when the function cannot be evaluated
+        """
+        print_ = functools.partial(print, file=sys.stdout)
+
+        def first_line(exc):
+            return exc.args[0].split("\n")[0]
+
+        def debug_parameters(rv):
+            if isinstance(rv.owner.op, RandomVariable):
+                inputs = rv.owner.inputs[3:]
+            else:
+                inputs = [inp for inp in rv.owner.inputs if not isinstance(inp.type, RandomType)]
+            rv_inputs = pytensor.function(
+                self.value_vars,
+                self.replace_rvs_by_values(inputs),
+                on_unused_input="ignore",
+                mode=get_mode(None).excluding("inplace", "fusion"),
+            )
+
+            print_(f"The variable {rv} has the following parameters:")
+            # done and used_ids are used to keep the same ids across distinct dprint calls
+            done = {}
+            used_ids = {}
+            for i, out in enumerate(rv_inputs.maker.fgraph.outputs):
+                print_(f"{i}: ", end="")
+                # Don't print useless deepcopys
+                if out.owner and isinstance(out.owner.op, DeepCopyOp):
+                    out = out.owner.inputs[0]
+                pytensor.dprint(out, print_type=True, done=done, used_ids=used_ids)
+
+            try:
+                print_("The parameters evaluate to:")
+                for i, rv_input_eval in enumerate(rv_inputs(**point)):
+                    print_(f"{i}: {rv_input_eval}")
+            except Exception as exc:
+                print_(
+                    f"The parameters of the variable {rv} cannot be evaluated: {first_line(exc)}"
+                )
+                if verbose:
+                    print_(exc, "\n")
+
+        if fn not in ("logp", "dlogp", "random"):
+            raise ValueError(f"fn must be one of [logp, dlogp, random], got {fn}")
+
+        if point is None:
+            point = self.initial_point()
+        print_(f"point={point}\n")
+
+        rvs_to_check = list(self.basic_RVs)
+        if fn in ("logp", "dlogp"):
+            rvs_to_check += [self.replace_rvs_by_values(p) for p in self.potentials]
+
+        found_problem = False
+        for rv in rvs_to_check:
+            if fn == "logp":
+                rv_fn = pytensor.function(
+                    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
+                )
+            elif fn == "dlogp":
+                rv_fn = pytensor.function(
+                    self.value_vars, self.dlogp(vars=rv), on_unused_input="ignore"
+                )
+            else:
+                [rv_inputs_replaced] = replace_rvs_by_values(
+                    [rv],
+                    # Don't include itself, or the function will just return the value variable
+                    rvs_to_values={
+                        rv_key: value
+                        for rv_key, value in self.rvs_to_values.items()
+                        if rv_key is not rv
+                    },
+                    rvs_to_transforms=self.rvs_to_transforms,
+                )
+                rv_fn = pytensor.function(
+                    self.value_vars, rv_inputs_replaced, on_unused_input="ignore"
+                )
+
+            try:
+                rv_fn_eval = rv_fn(**point)
+            except ParameterValueError as exc:
+                found_problem = True
+                debug_parameters(rv)
+                print_(
+                    f"This does not respect one of the following constraints: {first_line(exc)}\n"
+                )
+                if verbose:
+                    print_(exc)
+            except Exception as exc:
+                found_problem = True
+                debug_parameters(rv)
+                print_(
+                    f"The variable {rv} {fn} method raised the following exception: {first_line(exc)}\n"
+                )
+                if verbose:
+                    print_(exc)
+            else:
+                if not np.all(np.isfinite(rv_fn_eval)):
+                    found_problem = True
+                    debug_parameters(rv)
+                    if fn == "random" or rv in self.potentials:
+                        print_("This combination seems able to generate non-finite values")
+                    else:
+                        # Find which values are associated with non-finite evaluation
+                        values = self.rvs_to_values[rv]
+                        if rv in self.observed_RVs:
+                            values = values.eval()
+                        else:
+                            values = point[values.name]
+
+                        observed = " observed " if rv in self.observed_RVs else " "
+                        print_(
+                            f"Some of the{observed}values of variable {rv} are associated with a non-finite {fn}:"
+                        )
+                        mask = ~np.isfinite(rv_fn_eval)
+                        for value, fn_eval in zip(values[mask], rv_fn_eval[mask]):
+                            print_(f" value = {value} -> {fn} = {fn_eval}")
+                        print_()
+
+        if not found_problem:
+            print_("No problems found")
+        elif not verbose:
+            print_("You can set `verbose=True` for more details")
+
 
 # this is really disgusting, but it breaks a self-loop: I can't pass Model
 # itself as context class init arg.
diff --git a/tests/test_model.py b/tests/test_model.py
index a8c7dc2c54..c6b176af96 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -30,6 +30,7 @@
 import scipy.stats as st
 
 from pytensor.graph import graph_inputs
+from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.sharedvar import ScalarSharedVariable
@@ -1553,3 +1554,74 @@ def test_tag_future_warning_model():
             assert y_value.eval() == 5
 
     assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)
+
+
+class TestModelDebug:
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_no_problems(self, fn, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+        m.debug(fn=fn)
+
+        out, _ = capfd.readouterr()
+        assert out == "point={'x': array([ 1., -1.,  1.])}\n\nNo problems found\n"
+
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_invalid_parameter(self, fn, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+            y = pm.HalfNormal("y", tau=x)
+        m.debug(fn=fn)
+
+        out, _ = capfd.readouterr()
+        if fn == "dlogp":
+            # The dlogp is 0 or 1 without a likelihood, so no problem is found
+            assert "No problems found" in out
+        else:
+            assert "The parameters evaluate to:\n0: 0.0\n1: [ 1. -1.  1.]" in out
+            if fn == "logp":
+                assert "This does not respect one of the following constraints: sigma > 0" in out
+            else:
+                assert (
+                    "The variable y random method raised the following exception: Domain error in arguments."
+                    in out
+                )
+
+    @pytest.mark.parametrize("verbose", (True, False))
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_invalid_parameter_cant_be_evaluated(self, fn, verbose, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, 1, 1])
+            sigma = Assert(msg="x > 0")(pm.math.abs(x), (x > 0).all())
+            y = pm.HalfNormal("y", sigma=sigma)
+        m.debug(point={"x": [-1, -1, -1], "y_log__": [0, 0, 0]}, fn=fn, verbose=verbose)
+
+        out, _ = capfd.readouterr()
+        assert "{'x': [-1, -1, -1], 'y_log__': [0, 0, 0]}" in out
+        assert "The parameters of the variable y cannot be evaluated: x > 0" in out
+        verbose_str = "Apply node that caused the error:" in out
+        assert verbose_str if verbose else not verbose_str
+
+    def test_invalid_value(self, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], transform=None)
+        m.debug()
+
+        out, _ = capfd.readouterr()
+        assert "The parameters evaluate to:\n0: 0.0\n1: [1. 1. 1.]" in out
+        assert "Some of the values of variable y are associated with a non-finite logp" in out
+        assert "value = -1.0 -> logp = -inf" in out
+
+    def test_invalid_observed_value(self, capfd):
+        with pm.Model() as m:
+            theta = pm.Uniform("theta", lower=0, upper=1)
+            y = pm.Uniform("y", lower=0, upper=theta, observed=[0.49, 0.27, 0.53, 0.19])
+        m.debug()
+
+        out, _ = capfd.readouterr()
+        assert "The parameters evaluate to:\n0: 0.0\n1: 0.5" in out
+        assert (
+            "Some of the observed values of variable y are associated with a non-finite logp" in out
+        )
+        assert "value = 0.53 -> logp = -inf" in out
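
Usage sketch (illustrative, not part of the patch). The model below is made
up, but the `debug()` calls exercise only the API added above; "y_log__" is
the default transformed value-variable name, as seen in the tests:

    import pymc as pm

    with pm.Model() as model:
        x = pm.Normal("x", mu=0, sigma=1)
        # x starts at 0, so sigma=x violates y's sigma > 0 constraint
        y = pm.HalfNormal("y", sigma=x)

    # Evaluate the logp of each variable at the initial point and report any
    # parametrization that fails or evaluates to a non-finite value
    model.debug()

    # Debug the forward sampling graph instead of the logp
    model.debug(fn="random")

    # Evaluate at a custom point; verbose=True shows the full PyTensor errors
    model.debug(point={"x": -1.0, "y_log__": 0.0}, verbose=True)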