Skip to content

Commit

Permalink
updated copy method docs and simplified TestModelCopy tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekermanjian committed Sep 30, 2024
1 parent 90419cb commit 07106ec
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
5 changes: 3 additions & 2 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,8 +1579,9 @@ def __deepcopy__(self, _):

def copy(self):
"""
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
Clone the model
To access variables in the cloned model use `cloned_model["var_name"]`.
Examples
--------
Expand Down
42 changes: 18 additions & 24 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,17 +1765,13 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:


class TestModelCopy:
@staticmethod
def simple_model() -> pm.Model:
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_copy_model(self, copy_method) -> None:
with pm.Model() as simple_model:
error = pm.HalfNormal("error", 0.5)
alpha = pm.Normal("alpha", 0, 1)
pm.Normal("y", alpha, error)
return simple_model

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_copy_model(self, copy_method) -> None:
simple_model = self.simple_model()
copy_simple_model = copy_method(simple_model)

with simple_model:
Expand All @@ -1786,15 +1782,24 @@ def test_copy_model(self, copy_method) -> None:
samples=1, random_seed=42
)

simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean(
("chain", "draw")
)
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][
simple_model_prior_predictive_val = simple_model_prior_predictive["prior"]["y"].values
copy_simple_model_prior_predictive_val = copy_simple_model_prior_predictive["prior"][
"y"
].mean(("chain", "draw"))
].values

assert np.isclose(
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean
assert simple_model_prior_predictive_val == copy_simple_model_prior_predictive_val

with copy_simple_model:
z = pm.Deterministic("z", copy_simple_model["alpha"] + 1)
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
samples=1, random_seed=42
)

assert "z" in copy_simple_model.named_vars
assert "z" not in simple_model.named_vars
assert (
copy_simple_model_prior_predictive["prior"]["z"].values
== 1 + copy_simple_model_prior_predictive["prior"]["alpha"].values
)

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
Expand All @@ -1811,14 +1816,3 @@ def test_guassian_process_copy_failure(self, copy_method) -> None:
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
):
copy_method(gaussian_process_model)

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_adding_deterministics_to_clone(self, copy_method) -> None:
simple_model = self.simple_model()
clone_model = copy_method(simple_model)

with clone_model:
z = pm.Deterministic("z", clone_model["alpha"] + 1)

assert "z" in clone_model.named_vars
assert "z" not in simple_model.named_vars

0 comments on commit 07106ec

Please sign in to comment.