Skip to content

Commit

Permalink
Allow opting out of model nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 10, 2024
1 parent bc52c86 commit 17106cf
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
18 changes: 12 additions & 6 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Literal,
Optional,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -468,6 +469,11 @@ class Model(WithMemoization, metaclass=ContextMeta):
# Variable will belong to root and second
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"
# Set None for standalone model
with pm.Model(name="third", model=None) as third:
# Variable will belong to third only
w = pm.Normal("w") # Variable wil be named "third::w"
Set `check_bounds` to False for models with only continuous variables and default transformers
PyMC will remove the bounds check from the model logp which can speed up sampling
Expand All @@ -488,13 +494,13 @@ def __enter__(self: Self) -> Self: ...

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...

def __new__(cls, *args, **kwargs):
def __new__(cls, *args, model: Union[UNSET, None, "Model"] = UNSET, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if kwargs.get("model") is not None:
instance._parent = kwargs.get("model")
else:
if model is UNSET:
instance._parent = cls.get_context(error_if_none=False)
else:
instance._parent = model
return instance

@staticmethod
Expand All @@ -510,9 +516,9 @@ def __init__(
check_bounds=True,
*,
coords_mutable=None,
model=None,
model: Union[UNSET, None, "Model"] = UNSET,
):
del model # used in __new__
del model # used in __new__ to define the parent of this model
self.name = self._validate_name(name)
self.check_bounds = check_bounds

Expand Down
4 changes: 1 addition & 3 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ def first_non_model_var(var):
else:
return var

model = Model()
if model.parent is not None:
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
model = Model(model=None) # Do not inherit from any model in the context manager

_coords = getattr(fgraph, "_coords", {})
_dim_lengths = getattr(fgraph, "_dim_lengths", {})
Expand Down
16 changes: 14 additions & 2 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,20 @@ def test_docstring_example(self):
# Variable will belong to root and second
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"

# Set None for standalone model
with pm.Model(name="third", model=None) as third:
# Variable will belong to third only
w = pm.Normal("w") # Variable wil be named "third::w"

assert x.name == "root::x"
assert y.name == "root::first::y"
assert z.name == "root::second::z"
assert w.name == "third::w"

assert set(root.basic_RVs) == {x, y, z}
assert set(first.basic_RVs) == {y}
assert set(second.basic_RVs) == {z}
assert set(third.basic_RVs) == {w}


class TestNested:
Expand Down Expand Up @@ -1106,11 +1113,16 @@ def test_model_parent_set_programmatically():
y = pm.Normal("y")

with model:
# Explict None opts out of model context
with pm.Model():
z_in = pm.Normal("z_in")

with pm.Model(model=None):
z = pm.Normal("z")
z_out = pm.Normal("z_out")

assert "y" in model.named_vars
assert "z" in model.named_vars
assert "z_in" not in model.named_vars
assert "z_out" not in model.named_vars


class TestModelContext:
Expand Down
11 changes: 8 additions & 3 deletions tests/model/test_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,15 @@ def test_context_error():
with pm.Model() as m:
x = pm.Normal("x")

fg = fgraph_from_model(m)
fg, _ = fgraph_from_model(m)

with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
model_from_fgraph(fg)
new_m = model_from_fgraph(fg)
new_x = new_m["x"]

assert new_m.parent is None
assert x != new_x
assert m.named_vars == {"x": x}
assert new_m.named_vars == {"x": new_x}


def test_sub_model_error():
Expand Down

0 comments on commit 17106cf

Please sign in to comment.