Skip to content

Commit

Permalink
Add model debug helper
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 31, 2023
1 parent 4c64eb9 commit 6404805
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 1 deletion.
154 changes: 153 additions & 1 deletion pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools
import sys
import threading
import types
import warnings
Expand All @@ -24,6 +25,7 @@
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Expand All @@ -39,13 +41,15 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor.compile import DeepCopyOp, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.var import TensorConstant, TensorVariable

Expand All @@ -61,6 +65,7 @@
)
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import joint_logp
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
PointFunc,
SeedSequenceSeed,
Expand Down Expand Up @@ -1779,7 +1784,8 @@ def check_start_vals(self, start):
raise SamplingError(
"Initial evaluation of model at starting point failed!\n"
f"Starting values:\n{elem}\n\n"
f"Logp initial evaluation results:\n{initial_eval}"
f"Logp initial evaluation results:\n{initial_eval}\n"
"You can call `model.debug()` for more details."
)

def point_logps(self, point=None, round_vals=2):
Expand Down Expand Up @@ -1811,6 +1817,152 @@ def point_logps(self, point=None, round_vals=2):
)
}

def debug(
    self,
    point: Optional[Dict[str, np.ndarray]] = None,
    fn: Literal["logp", "dlogp", "random"] = "logp",
    verbose: bool = False,
):
    """Debug model function at point.

    The method will evaluate the `fn` for each variable at a time.
    When an evaluation fails or produces a non-finite value we print:

    1. The graph of the parameters
    2. The value of the parameters (if those can be evaluated)
    3. The output of `fn` (if it can be evaluated)

    This function should help to quickly narrow down invalid parametrizations.

    Parameters
    ----------
    point : Point
        Point at which model function should be evaluated.
        When omitted, ``self.initial_point()`` is used.
    fn : str, default "logp"
        Function to be used for debugging. Can be one of [logp, dlogp, random].
    verbose : bool, default False
        Whether to show a more verbose PyTensor output when function cannot be evaluated

    Raises
    ------
    ValueError
        If `fn` is not one of "logp", "dlogp" or "random".
    """
    # Fail fast on a bad `fn`, before any graph is built or compiled.
    if fn not in ("logp", "dlogp", "random"):
        raise ValueError(f"fn must be one of [logp, dlogp, random], got {fn}")

    print_ = functools.partial(print, file=sys.stdout)

    def first_line(exc):
        # Not every exception carries a string as its first argument (or any
        # argument at all); fall back to str(exc) so reporting never fails.
        if exc.args and isinstance(exc.args[0], str):
            message = exc.args[0]
        else:
            message = str(exc)
        return message.split("\n")[0]

    def debug_parameters(rv):
        # For a RandomVariable the first three inputs are skipped; for any
        # other op only the inputs of RandomType are dropped.
        if isinstance(rv.owner.op, RandomVariable):
            inputs = rv.owner.inputs[3:]
        else:
            inputs = [inp for inp in rv.owner.inputs if not isinstance(inp.type, RandomType)]
        rv_inputs = pytensor.function(
            self.value_vars,
            self.replace_rvs_by_values(inputs),
            on_unused_input="ignore",
            mode=get_mode(None).excluding("inplace", "fusion"),
        )

        print_(f"The variable {rv} has the following parameters:")
        # done and used_ids are used to keep the same ids across distinct dprint calls
        done = {}
        used_ids = {}
        for i, out in enumerate(rv_inputs.maker.fgraph.outputs):
            print_(f"{i}: ", end="")
            # Don't print useless deepcopys
            if out.owner and isinstance(out.owner.op, DeepCopyOp):
                out = out.owner.inputs[0]
            pytensor.dprint(out, print_type=True, done=done, used_ids=used_ids)

        try:
            print_("The parameters evaluate to:")
            for i, rv_input_eval in enumerate(rv_inputs(**point)):
                print_(f"{i}: {rv_input_eval}")
        except Exception as exc:
            print_(
                f"The parameters of the variable {rv} cannot be evaluated: {first_line(exc)}"
            )
            if verbose:
                print_(exc, "\n")

    if point is None:
        point = self.initial_point()
    print_(f"point={point}\n")

    rvs_to_check = list(self.basic_RVs)
    # Entries appended after the basic RVs are potential terms. They are the
    # value-replaced graphs, not the objects stored in `self.potentials`, so
    # they are recognised by position rather than by identity/membership.
    n_basic_rvs = len(rvs_to_check)
    if fn in ("logp", "dlogp"):
        rvs_to_check += [self.replace_rvs_by_values(p) for p in self.potentials]

    found_problem = False
    for idx, rv in enumerate(rvs_to_check):
        is_potential = idx >= n_basic_rvs

        if fn == "logp":
            rv_fn = pytensor.function(
                self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
            )
        elif fn == "dlogp":
            rv_fn = pytensor.function(
                self.value_vars, self.dlogp(vars=rv), on_unused_input="ignore"
            )
        else:
            [rv_inputs_replaced] = replace_rvs_by_values(
                [rv],
                # Don't include itself, or the function will just return the value variable
                rvs_to_values={
                    rv_key: value
                    for rv_key, value in self.rvs_to_values.items()
                    if rv_key is not rv
                },
                rvs_to_transforms=self.rvs_to_transforms,
            )
            rv_fn = pytensor.function(
                self.value_vars, rv_inputs_replaced, on_unused_input="ignore"
            )

        try:
            rv_fn_eval = rv_fn(**point)
        except ParameterValueError as exc:
            found_problem = True
            debug_parameters(rv)
            print_(
                f"This does not respect one of the following constraints: {first_line(exc)}\n"
            )
            if verbose:
                print_(exc)
        except Exception as exc:
            found_problem = True
            debug_parameters(rv)
            print_(
                f"The variable {rv} {fn} method raised the following exception: {first_line(exc)}\n"
            )
            if verbose:
                print_(exc)
        else:
            if not np.all(np.isfinite(rv_fn_eval)):
                found_problem = True
                debug_parameters(rv)
                # Random draws and potential terms have no associated value
                # variable, so there is nothing to look up in rvs_to_values.
                if fn == "random" or is_potential:
                    print_("This combination seems able to generate non-finite values")
                else:
                    # Find which values are associated with non-finite evaluation
                    values = self.rvs_to_values[rv]
                    if rv in self.observed_RVs:
                        values = values.eval()
                    else:
                        values = point[values.name]
                    # `point` may hold plain lists; boolean-mask indexing needs an array
                    values = np.asarray(values)

                    observed = " observed " if rv in self.observed_RVs else " "
                    print_(
                        f"Some of the{observed}values of variable {rv} are associated with a non-finite {fn}:"
                    )
                    mask = ~np.isfinite(rv_fn_eval)
                    for value, fn_eval in zip(values[mask], rv_fn_eval[mask]):
                        print_(f" value = {value} -> {fn} = {fn_eval}")
                    print_()

    if not found_problem:
        print_("No problems found")
    elif not verbose:
        print_("You can set `verbose=True` for more details")


# this is really disgusting, but it breaks a self-loop: I can't pass Model
# itself as context class init arg.
Expand Down
72 changes: 72 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import scipy.stats as st

from pytensor.graph import graph_inputs
from pytensor.raise_op import Assert, assert_op
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sharedvar import ScalarSharedVariable
Expand Down Expand Up @@ -1553,3 +1554,74 @@ def test_tag_future_warning_model():
assert y_value.eval() == 5

assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)


class TestModelDebug:
    """Tests for ``Model.debug``, checking the text it prints to stdout."""

    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_no_problems(self, fn, capfd):
        # A well-specified model prints only the point and the all-clear message.
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
        m.debug(fn=fn)

        out, _ = capfd.readouterr()
        assert out == "point={'x': array([ 1., -1., 1.])}\n\nNo problems found\n"

    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_invalid_parameter(self, fn, capfd):
        # A negative tau produces an invalid sigma that can still be evaluated and shown.
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
            y = pm.HalfNormal("y", tau=x)
        m.debug(fn=fn)

        out, _ = capfd.readouterr()
        if fn == "dlogp":
            # var dlogp is 0 or 1 without a likelihood
            assert "No problems found" in out
        else:
            assert "The parameters evaluate to:\n0: 0.0\n1: [ 1. -1. 1.]" in out
            if fn == "logp":
                assert "This does not respect one of the following constraints: sigma > 0" in out
            else:
                assert (
                    "The variable y random method raised the following exception: Domain error in arguments."
                    in out
                )

    @pytest.mark.parametrize("verbose", (True, False))
    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_invalid_parameter_cant_be_evaluated(self, fn, verbose, capfd):
        # The Assert op makes the parameter graph itself fail at the given point.
        with pm.Model() as m:
            x = pm.Normal("x", [1, 1, 1])
            sigma = Assert(msg="x > 0")(pm.math.abs(x), (x > 0).all())
            y = pm.HalfNormal("y", sigma=sigma)
        m.debug(point={"x": [-1, -1, -1], "y_log__": [0, 0, 0]}, fn=fn, verbose=verbose)

        out, _ = capfd.readouterr()
        assert "{'x': [-1, -1, -1], 'y_log__': [0, 0, 0]}" in out
        assert "The parameters of the variable y cannot be evaluated: x > 0" in out
        verbose_str = "Apply node that caused the error:" in out
        assert verbose_str if verbose else not verbose_str

    def test_invalid_value(self, capfd):
        # Valid parameters, but an initial value outside the support of y.
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], transform=None)
        m.debug()

        out, _ = capfd.readouterr()
        # These checks previously asserted a bare (always truthy) string literal.
        assert "The variable y has the following parameters:" in out
        assert "The parameters evaluate to:\n0: 0.0\n1: [1. 1. 1.]" in out
        assert "Some of the values of variable y are associated with a non-finite logp" in out
        assert "value = -1.0 -> logp = -inf" in out

    def test_invalid_observed_value(self, capfd):
        # Observed data above the upper bound implied by theta's initial value.
        with pm.Model() as m:
            theta = pm.Uniform("theta", lower=0, upper=1)
            y = pm.Uniform("y", lower=0, upper=theta, observed=[0.49, 0.27, 0.53, 0.19])
        m.debug()

        out, _ = capfd.readouterr()
        # This check previously asserted a bare (always truthy) string literal.
        assert "The variable y has the following parameters:" in out
        assert "The parameters evaluate to:\n0: 0.0\n1: 0.5" in out
        assert (
            "Some of the observed values of variable y are associated with a non-finite logp" in out
        )
        assert "value = 0.53 -> logp = -inf" in out

0 comments on commit 6404805

Please sign in to comment.