From 269da5c03c771a54fc0478704a4a3bb53d4aa9c8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Jun 2024 18:14:02 +0200 Subject: [PATCH] Allow opting out of model nesting --- pymc/model/core.py | 20 +++++++++----- pymc/model/fgraph.py | 4 +-- pymc/model/transform/basic.py | 2 +- pymc/model/transform/conditioning.py | 2 +- pymc/sampling/deterministic.py | 2 +- pymc/stats/log_density.py | 39 +++++++++++----------------- tests/model/test_core.py | 17 ++++++++++-- tests/model/test_fgraph.py | 11 +++++--- 8 files changed, 55 insertions(+), 42 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 7aec544d79f..475ff42b5f7 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -25,6 +25,7 @@ Literal, Optional, TypeVar, + Union, cast, overload, ) @@ -441,7 +442,7 @@ class Model(WithMemoization, metaclass=ContextMeta): coords = { "feature", ["A", "B", "C"], - "trial", [1, 2, 3, 4, 5], + "trial", [1, 2, 3, 4, 5], } with pm.Model(coords=coords) as model: @@ -476,6 +477,11 @@ class Model(WithMemoization, metaclass=ContextMeta): # Variable will belong to root and second z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z" + # Set None for standalone model + with pm.Model(name="third", model=None) as third: + # Variable will belong to third only + w = pm.Normal("w") # Variable wil be named "third::w" + Set `check_bounds` to False for models with only continuous variables and default transformers PyMC will remove the bounds check from the model logp which can speed up sampling @@ -497,13 +503,13 @@ def __enter__(self: Self) -> Self: ... def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ... - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs): # resolves the parent instance instance = super().__new__(cls) - if kwargs.get("model") is not None: - instance._parent = kwargs.get("model") - else: + if model is UNSET: instance._parent = cls.get_context(error_if_none=False) + else: + instance._parent = model return instance @staticmethod @@ -519,9 +525,9 @@ def __init__( check_bounds=True, *, coords_mutable=None, - model=None, + model: Union[Literal[UNSET], None, "Model"] = UNSET, ): - del model # used in __new__ + del model # used in __new__ to define the parent of this model self.name = self._validate_name(name) self.check_bounds = check_bounds diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index ce15c40760f..b1d67fd07b0 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -299,9 +299,7 @@ def first_non_model_var(var): else: return var - model = Model() - if model.parent is not None: - raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context") + model = Model(model=None) # Do not inherit from any model in the context manager _coords = getattr(fgraph, "_coords", {}) _dim_lengths = getattr(fgraph, "_dim_lengths", {}) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index aff6042de5a..76556ae08ab 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -16,7 +16,7 @@ from pytensor import Variable from pytensor.graph import ancestors -from pymc import Model +from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, ModelVar, diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py index 0979964eecf..531ec2bd6b9 100644 --- a/pymc/model/transform/conditioning.py +++ b/pymc/model/transform/conditioning.py @@ -19,9 +19,9 @@ from pytensor.graph import ancestors from pytensor.tensor import TensorVariable -from pymc import Model from pymc.logprob.transforms import Transform from pymc.logprob.utils import rvs_in_graph +from pymc.model.core import Model from pymc.model.fgraph import ( ModelDeterministic, ModelFreeRV, diff --git a/pymc/sampling/deterministic.py b/pymc/sampling/deterministic.py index b0b04f38ec7..3d8398c3a7e 100644 --- a/pymc/sampling/deterministic.py +++ b/pymc/sampling/deterministic.py @@ -83,7 +83,7 @@ def compute_deterministics( model = modelcontext(model) if var_names is None: - deterministics = model.deterministics + deterministics = list(model.deterministics) var_names = [det.name for det in deterministics] else: deterministics = [model[var_name] for var_name in var_names] diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index 5b6406d02b1..a26f8aa60df 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -25,6 +25,8 @@ __all__ = ("compute_log_likelihood", "compute_log_prior") +from pymc.model.transform.conditioning import remove_value_transforms + def compute_log_likelihood( idata: InferenceData, @@ -126,46 +128,35 @@ def compute_log_density( if kind not in ("likelihood", "prior"): raise ValueError("kind must be either 'likelihood' or 'prior'") + # We need to disable transforms, because the InferenceData only keeps the untransformed values + umodel = remove_value_transforms(model) + if kind == "likelihood": - target_rvs = model.observed_RVs + target_rvs = list(umodel.observed_RVs) target_str = "observed_RVs" else: - target_rvs = model.free_RVs + target_rvs = list(umodel.free_RVs) target_str = "free_RVs" if var_names is None: vars = target_rvs var_names = tuple(rv.name for rv in vars) else: - vars = [model.named_vars[name] for name in var_names] + vars = [umodel.named_vars[name] for name in var_names] if not set(vars).issubset(target_rvs): raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}") - # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values - try: - original_rvs_to_values = model.rvs_to_values - original_rvs_to_transforms = model.rvs_to_transforms - - model.rvs_to_values = { - rv: rv.clone() if rv not in model.observed_RVs else value - for rv, value in model.rvs_to_values.items() - } - model.rvs_to_transforms = {rv: None for rv in model.basic_RVs} - - elemwise_logdens_fn = model.compile_fn( - inputs=model.value_vars, - outs=model.logp(vars=vars, sum=False), - on_unused_input="ignore", - ) - finally: - model.rvs_to_values = original_rvs_to_values - model.rvs_to_transforms = original_rvs_to_transforms + elemwise_logdens_fn = umodel.compile_fn( + inputs=umodel.value_vars, + outs=umodel.logp(vars=vars, sum=False), + on_unused_input="ignore", + ) - coords, dims = coords_and_dims_for_inferencedata(model) + coords, dims = coords_and_dims_for_inferencedata(umodel) logdens_dataset = apply_function_over_dataset( elemwise_logdens_fn, - posterior[[rv.name for rv in model.free_RVs]], + posterior[[rv.name for rv in umodel.free_RVs]], output_var_names=var_names, sample_dims=sample_dims, dims=dims, diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 484ac76d9c6..e43a86cebfe 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -143,13 +143,20 @@ def test_docstring_example(self): # Variable will belong to root and second z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z" + # Set None for standalone model + with pm.Model(name="third", model=None) as third: + # Variable will belong to third only + w = pm.Normal("w") # Variable wil be named "third::w" + assert x.name == "root::x" assert y.name == "root::first::y" assert z.name == "root::second::z" + assert w.name == "third::w" assert set(root.basic_RVs) == {x, y, z} assert set(first.basic_RVs) == {y} assert set(second.basic_RVs) == {z} + assert set(third.basic_RVs) == {w} class TestNested: @@ -1106,11 +1113,17 @@ def test_model_parent_set_programmatically(): y = pm.Normal("y") with model: + # Default inherits from model + with pm.Model(): + z_in = pm.Normal("z_in") + + # Explict None opts out of model context with pm.Model(model=None): - z = pm.Normal("z") + z_out = pm.Normal("z_out") assert "y" in model.named_vars - assert "z" in model.named_vars + assert "z_in" in model.named_vars + assert "z_out" not in model.named_vars class TestModelContext: diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py index 7a57bfc16a4..a964f1faf6d 100644 --- a/tests/model/test_fgraph.py +++ b/tests/model/test_fgraph.py @@ -267,10 +267,15 @@ def test_context_error(): with pm.Model() as m: x = pm.Normal("x") - fg = fgraph_from_model(m) + fg, _ = fgraph_from_model(m) - with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"): - model_from_fgraph(fg) + new_m = model_from_fgraph(fg) + new_x = new_m["x"] + + assert new_m.parent is None + assert x != new_x + assert m.named_vars == {"x": x} + assert new_m.named_vars == {"x": new_x} def test_sub_model_error():