From 5c833fda98d89ed68323284b1377d77b81397535 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 28 Aug 2023 12:17:14 +0200 Subject: [PATCH] Graduate fgraph functionality from pymc-experimental --- .github/workflows/tests.yml | 9 +- docs/source/api.rst | 4 +- docs/source/api/{model.rst => model/core.rst} | 2 +- docs/source/api/model/fgraph.rst | 10 + .../api/model/transform/conditioning.rst | 18 + .../contributing/running_the_test_suite.md | 6 +- pymc/__init__.py | 3 +- pymc/distributions/distribution.py | 2 +- pymc/model/__init__.py | 15 + pymc/{model.py => model/core.py} | 0 pymc/model/fgraph.py | 396 +++++++++++++++++ pymc/model/transform/__init__.py | 13 + pymc/model/transform/basic.py | 59 +++ pymc/model/transform/conditioning.py | 402 ++++++++++++++++++ pymc/pytensorf.py | 45 +- scripts/run_mypy.py | 5 +- tests/model/__init__.py | 13 + tests/{test_model.py => model/test_core.py} | 5 +- tests/model/test_fgraph.py | 353 +++++++++++++++ tests/model/transform/__init__.py | 13 + tests/model/transform/test_basic.py | 32 ++ tests/model/transform/test_conditioning.py | 310 ++++++++++++++ 22 files changed, 1700 insertions(+), 15 deletions(-) rename docs/source/api/{model.rst => model/core.rst} (89%) create mode 100644 docs/source/api/model/fgraph.rst create mode 100644 docs/source/api/model/transform/conditioning.rst create mode 100644 pymc/model/__init__.py rename pymc/{model.py => model/core.py} (100%) create mode 100644 pymc/model/fgraph.py create mode 100644 pymc/model/transform/__init__.py create mode 100644 pymc/model/transform/basic.py create mode 100644 pymc/model/transform/conditioning.py create mode 100644 tests/model/__init__.py rename tests/{test_model.py => model/test_core.py} (99%) create mode 100644 tests/model/test_fgraph.py create mode 100644 tests/model/transform/__init__.py create mode 100644 tests/model/transform/test_basic.py create mode 100644 tests/model/transform/test_conditioning.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8f932e0747e..00976e52ecf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -93,7 +93,10 @@ jobs: tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py - tests/test_model.py + tests/model/test_core.py + tests/model/test_fgraph.py + tests/model/transform/test_basic.py + tests/model/transform/test_conditioning.py tests/test_model_graph.py tests/ode/test_ode.py tests/ode/test_utils.py @@ -187,7 +190,7 @@ jobs: python-version: ["3.9"] test-subset: - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py - - tests/test_model.py tests/sampling/test_mcmc.py + - tests/model/test_core.py tests/sampling/test_mcmc.py - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py @@ -266,7 +269,7 @@ jobs: tests/sampling/test_parallel.py tests/test_data.py tests/variational/test_minibatch_rv.py - tests/test_model.py + tests/model/test_core.py - | tests/sampling/test_mcmc.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 4aa717a2dcb..4f62539da60 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,7 +9,9 @@ API api/distributions api/gp - api/model + 
api/model/core
+   api/model/fgraph
+   api/model/transform/conditioning
    api/samplers
    api/vi
    api/smc
diff --git a/docs/source/api/model.rst b/docs/source/api/model/core.rst
similarity index 89%
rename from docs/source/api/model.rst
rename to docs/source/api/model/core.rst
index 0c288978aae..47220162c98 100644
--- a/docs/source/api/model.rst
+++ b/docs/source/api/model/core.rst
@@ -4,7 +4,7 @@ Model
 Model creation and inspection
 -----------------------------
 
-.. currentmodule:: pymc
+.. currentmodule:: pymc.model.core
 
 .. autosummary::
    :toctree: generated/
diff --git a/docs/source/api/model/fgraph.rst b/docs/source/api/model/fgraph.rst
new file mode 100644
index 00000000000..8a9aed2964b
--- /dev/null
+++ b/docs/source/api/model/fgraph.rst
@@ -0,0 +1,10 @@
+FunctionGraph
+-------------
+
+.. currentmodule:: pymc.model.fgraph
+.. autosummary::
+   :toctree: generated/
+
+   clone_model
+   fgraph_from_model
+   model_from_fgraph
diff --git a/docs/source/api/model/transform/conditioning.rst b/docs/source/api/model/transform/conditioning.rst
new file mode 100644
index 00000000000..004822c82cc
--- /dev/null
+++ b/docs/source/api/model/transform/conditioning.rst
@@ -0,0 +1,18 @@
+Model Conditioning
+------------------
+
+.. currentmodule:: pymc.model.transform.conditioning
+.. autosummary::
+   :toctree: generated/
+
+   do
+   observe
+
+Others
+------
+
+.. autosummary::
+   :toctree: generated/
+
+   change_value_transforms
+   remove_value_transforms
diff --git a/docs/source/contributing/running_the_test_suite.md b/docs/source/contributing/running_the_test_suite.md
index 2e50d7678ff..e1b16ed9dc4 100644
--- a/docs/source/contributing/running_the_test_suite.md
+++ b/docs/source/contributing/running_the_test_suite.md
@@ -19,7 +19,7 @@ Therefore, we recommend to run just specific tests that target the parts of the
 To run all tests from a single file:
 
 ```bash
-pytest -v tests/test_model.py
+pytest -v tests/model/test_core.py
 ```
 
 ```{tip}
@@ -28,8 +28,8 @@ The `-v` flag is short-hand for `--verbose` and prints the names of the test cas
 
 Often, you'll want to focus on just a few test cases first. By using the `-k` flag, you can filter for test cases that match a certain pattern.
 
-For example, the following command runs all test cases from `test_model.py` that have "coord" in their name: +For example, the following command runs all test cases from `test_core.py` that have "coord" in their name: ```bash -pytest -v tests/test_model.py -k coord +pytest -v tests/model/test_core.py -k coord ``` diff --git a/pymc/__init__.py b/pymc/__init__.py index 2fca97dca15..b6c38b07dc7 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -63,7 +63,8 @@ def __set_compiler_flags(): logsumexp, probit, ) -from pymc.model import * +from pymc.model.core import * +from pymc.model.transform.conditioning import do, observe from pymc.model_graph import model_to_graphviz, model_to_networkx from pymc.plots import * from pymc.printing import * diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 8056fbd6bb3..697a3b9178f 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -53,7 +53,7 @@ from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob from pymc.logprob.basic import logp from pymc.logprob.rewriting import logprob_rewrites_db -from pymc.model import new_or_existing_block_model_access +from pymc.model.core import new_or_existing_block_model_access from pymc.printing import str_for_dist from pymc.pytensorf import ( collect_default_updates, diff --git a/pymc/model/__init__.py b/pymc/model/__init__.py new file mode 100644 index 00000000000..bd70de3e276 --- /dev/null +++ b/pymc/model/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pymc.model.core import * +from pymc.model.core import ValueGradFunction diff --git a/pymc/model.py b/pymc/model/core.py similarity index 100% rename from pymc/model.py rename to pymc/model/core.py diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py new file mode 100644 index 00000000000..4533ef4ef71 --- /dev/null +++ b/pymc/model/fgraph.py @@ -0,0 +1,396 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
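+"""Convert PyMC Models to and from PyTensor FunctionGraph representations."""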
+from copy import copy +from typing import Dict, Optional, Tuple + +import pytensor + +from pytensor import Variable, shared +from pytensor.compile import SharedVariable +from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter +from pytensor.graph.rewriting.basic import out2in +from pytensor.scalar import Identity +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.sharedvar import ScalarSharedVariable + +from pymc.logprob.transforms import RVTransform +from pymc.model.core import Model +from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace + + +class ModelVar(Op): + """A dummy Op that describes the purpose of a Model variable and contains + meta-information as additional inputs (value and dims). + """ + + def make_node(self, rv, *dims): + assert isinstance(rv, Variable) + dims = self._parse_dims(rv, *dims) + return Apply(self, [rv, *dims], [rv.type(name=rv.name)]) + + def _parse_dims(self, rv, *dims): + if dims: + dims = [pytensor.as_symbolic(dim) for dim in dims] + assert all(isinstance(dim.type, StringType) for dim in dims) + assert len(dims) == rv.type.ndim + return dims + + def infer_shape(self, fgraph, node, inputs_shape): + return [inputs_shape[0]] + + def do_constant_folding(self, fgraph, node): + return False + + def perform(self, *args, **kwargs): + raise RuntimeError("ModelVars should never be in a final graph!") + + +class ModelValuedVar(ModelVar): + __props__ = ("transform",) + + def __init__(self, transform: Optional[RVTransform] = None): + if transform is not None and not isinstance(transform, RVTransform): + raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}") + self.transform = transform + super().__init__() + + def make_node(self, rv, value, *dims): + assert isinstance(rv, Variable) + dims = self._parse_dims(rv, *dims) + if value is not None: + assert isinstance(value, Variable) + assert rv.type.in_same_class(value.type) + return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)]) + + +class ModelFreeRV(ModelValuedVar): + pass + + +class ModelObservedRV(ModelValuedVar): + pass + + +class ModelPotential(ModelVar): + pass + + +class ModelDeterministic(ModelVar): + pass + + +class ModelNamed(ModelVar): + pass + + +def model_free_rv(rv, value, transform, *dims): + return ModelFreeRV(transform=transform)(rv, value, *dims) + + +model_observed_rv = ModelObservedRV() +model_potential = ModelPotential() +model_deterministic = ModelDeterministic() +model_named = ModelNamed() + + +@node_rewriter([Elemwise]) +def local_remove_identity(fgraph, node): + if isinstance(node.op.scalar_op, Identity): + return [node.inputs[0]] + + +remove_identity_rewrite = out2in(local_remove_identity) + + +def fgraph_from_model( + model: Model, inlined_views=False +) -> Tuple[FunctionGraph, Dict[Variable, Variable]]: + """Convert Model to FunctionGraph. + + See: model_from_fgraph + + Parameters + ---------- + model: PyMC model + inlined_views: bool, default False + Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph, + or show up as separate branches. + + Returns + ------- + fgraph: FunctionGraph + FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops. + It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`. + + memo: Dict + A dictionary mapping original model variables to the equivalent nodes in the fgraph. 
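+
+    Examples
+    --------
+    A minimal round-trip sketch (the model here is purely illustrative):
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
+
+        with pm.Model() as m:
+            mu = pm.Normal("mu")
+            y = pm.Normal("y", mu, observed=[0.0, 1.0])
+
+        # Convert the model to a FunctionGraph and rebuild an equivalent model from it
+        fgraph, memo = fgraph_from_model(m)
+        m_new = model_from_fgraph(fgraph)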
+ """ + + if any(v is not None for v in model.rvs_to_initial_values.values()): + raise NotImplementedError("Cannot convert models with non-default initial_values") + + if model.parent is not None: + raise ValueError( + "Nested sub-models cannot be converted to fgraph. Convert the parent model instead" + ) + + # Collect PyTensor variables + rvs_to_values = model.rvs_to_values + rvs = list(rvs_to_values.keys()) + free_rvs = model.free_RVs + observed_rvs = model.observed_RVs + potentials = model.potentials + named_vars = model.named_vars.values() + # We copy Deterministics (Identity Op) so that they don't show in between "main" variables + # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator + old_deterministics = model.deterministics + deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics] + # Value variables (we also have to decide whether to inline named ones) + old_value_vars = list(rvs_to_values.values()) + unnamed_value_vars = [val for val in old_value_vars if val not in named_vars] + named_value_vars = [ + val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars + ] + value_vars = old_value_vars.copy() + if inlined_views: + # In this case we want to use the named_value_vars as the value_vars in RVs + for named_val in named_value_vars: + idx = value_vars.index(named_val) + value_vars[idx] = named_val + # Other variables that are in named_vars but are not any of the categories above + # E.g., MutableData, ConstantData, _dim_lengths + # We use the same trick as deterministics! + accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars) + other_named_vars = [ + var if inlined_views else var.copy(var.name) + for var in named_vars + if var not in accounted_for + ] + + model_vars = ( + rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars + ) + + memo = {} + + # Replace the following shared variables in the model: + # 1. RNGs + # 2. MutableData (could increase memory usage significantly) + # 3. Mutable coords dim lengths + shared_vars_to_copy = find_rng_nodes(model_vars) + shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)] + shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)] + for var in shared_vars_to_copy: + # FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally, + # so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables! + # Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables... 
+ # https://github.com/pymc-devs/pytensor/issues/396 + if isinstance(var, ScalarSharedVariable): + new_var = shared(var.get_value(borrow=False).item()) + else: + new_var = shared(var.get_value(borrow=False)) + + assert new_var.type == var.type + new_var.name = var.name + new_var.tag = copy(var.tag) + # We can replace input variables by placing them in the memo + memo[var] = new_var + + fgraph = FunctionGraph( + outputs=model_vars, + clone=True, + memo=memo, + copy_orphans=True, + copy_inputs=True, + ) + # Copy model meta-info to fgraph + fgraph._coords = model._coords.copy() + fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()} + + rvs_to_transforms = model.rvs_to_transforms + named_vars_to_dims = model.named_vars_to_dims + + # Introduce dummy `ModelVar` Ops + free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()} + free_rvs_to_values = {memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in free_rvs} + observed_rvs_to_values = { + memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in observed_rvs + } + potentials = [memo[k] for k in potentials] + deterministics = [memo[k] for k in deterministics] + named_vars = [memo[k] for k in other_named_vars + named_value_vars] + + vars = fgraph.outputs + new_vars = [] + for var in vars: + dims = named_vars_to_dims.get(var.name, ()) + if var in free_rvs_to_values: + new_var = model_free_rv( + var, free_rvs_to_values[var], free_rvs_to_transforms[var], *dims + ) + elif var in observed_rvs_to_values: + new_var = model_observed_rv(var, observed_rvs_to_values[var], *dims) + elif var in potentials: + new_var = model_potential(var, *dims) + elif var in deterministics: + new_var = model_deterministic(var, *dims) + elif var in named_vars: + new_var = model_named(var, *dims) + else: + # Unnamed value variables + new_var = var + new_vars.append(new_var) + + replacements = tuple(zip(vars, new_vars)) + toposort_replace(fgraph, replacements, reverse=True) + + # Reference model vars in memo + inverse_memo = {v: k for k, v in memo.items()} + for var, model_var in replacements: + if not inlined_views and ( + model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed)) + ): + # Ignore extra identity that will be removed at the end + var = var.owner.inputs[0] + original_var = inverse_memo[var] + memo[original_var] = model_var + + # Remove the last outputs corresponding to unnamed value variables, now that they are graph inputs + first_idx_to_remove = len(fgraph.outputs) - len(unnamed_value_vars) + for _ in unnamed_value_vars: + fgraph.remove_output(first_idx_to_remove) + + # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph + remove_identity_rewrite.apply(fgraph) + + return fgraph, memo + + +def model_from_fgraph(fgraph: FunctionGraph) -> Model: + """Convert FunctionGraph to PyMC model. + + This requires nodes to be properly tagged with `ModelVar` dummy Ops. 
+
+    See: fgraph_from_model
+    """
+
+    def first_non_model_var(var):
+        if var.owner and isinstance(var.owner.op, ModelVar):
+            new_var = var.owner.inputs[0]
+            return first_non_model_var(new_var)
+        else:
+            return var
+
+    model = Model()
+    if model.parent is not None:
+        raise RuntimeError("model_from_fgraph cannot be called inside a PyMC model context")
+    model._coords = getattr(fgraph, "_coords", {})
+    model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
+
+    # Replace dummy `ModelVar` Ops by the underlying variables
+    fgraph = fgraph.clone()
+    model_dummy_vars = [
+        model_node.outputs[0]
+        for model_node in fgraph.toposort()
+        if isinstance(model_node.op, ModelVar)
+    ]
+    model_dummy_vars_to_vars = {
+        # Deterministics could refer to other model variables directly,
+        # so we make sure to replace them by the first non-model variable
+        dummy_var: first_non_model_var(dummy_var.owner.inputs[0])
+        for dummy_var in model_dummy_vars
+    }
+    toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
+
+    # Populate new PyMC model mappings
+    for model_var in model_dummy_vars:
+        if isinstance(model_var.owner.op, ModelFreeRV):
+            var, value, *dims = model_var.owner.inputs
+            transform = model_var.owner.op.transform
+            model.free_RVs.append(var)
+            # PyMC does not allow setting the transform when a value_var is passed,
+            # so we register it afterwards
+            model.create_value_var(var, transform=None, value_var=value)
+            model.rvs_to_transforms[var] = transform
+            model.set_initval(var, initval=None)
+        elif isinstance(model_var.owner.op, ModelObservedRV):
+            var, value, *dims = model_var.owner.inputs
+            model.observed_RVs.append(var)
+            model.create_value_var(var, transform=None, value_var=value)
+        elif isinstance(model_var.owner.op, ModelPotential):
+            var, *dims = model_var.owner.inputs
+            model.potentials.append(var)
+        elif isinstance(model_var.owner.op, ModelDeterministic):
+            var, *dims = model_var.owner.inputs
+            # If a Deterministic is a direct view on an RV, copy it
+            if var in model.basic_RVs:
+                var = var.copy()
+            model.deterministics.append(var)
+        elif isinstance(model_var.owner.op, ModelNamed):
+            var, *dims = model_var.owner.inputs
+        else:
+            raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
+
+        var.name = model_var.name
+        dims = [dim.data for dim in dims] if dims else None
+        model.add_named_variable(var, dims=dims)
+
+    return model
+
+
+def clone_model(model: Model) -> Model:
+    """Clone a PyMC model.
+
+    Recreates a PyMC model with clones of the original variables.
+    Shared variables will point to the same container but be otherwise different objects.
+    Constants are not cloned.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pymc as pm
+
+        from pymc.model.fgraph import clone_model
+
+        with pm.Model() as m:
+            p = pm.Beta("p", 1, 1)
+            x = pm.Bernoulli("x", p=p, shape=(3,))
+
+        with clone_model(m) as clone_m:
+            # Access cloned variables by name
+            clone_x = clone_m["x"]
+
+            # z will be part of clone_m but not m
+            z = pm.Deterministic("z", clone_x + 1)
+
+    """
+    return model_from_fgraph(fgraph_from_model(model)[0])
+
+
+def extract_dims(var) -> Tuple:
+    dims = ()
+    node = var.owner
+    if node and isinstance(node.op, ModelVar):
+        if isinstance(node.op, ModelValuedVar):
+            dims = node.inputs[2:]
+        else:
+            dims = node.inputs[1:]
+    return dims
+
+
+__all__ = (
+    "fgraph_from_model",
+    "model_from_fgraph",
+    "clone_model",
+)
diff --git a/pymc/model/transform/__init__.py b/pymc/model/transform/__init__.py
new file mode 100644
index 00000000000..fcb37fe8c79
--- /dev/null
+++ b/pymc/model/transform/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py
new file mode 100644
index 00000000000..849156946db
--- /dev/null
+++ b/pymc/model/transform/basic.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
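+"""Basic helpers shared by the model transformation modules."""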
+from typing import List, Sequence, Union
+
+from pytensor import Variable
+from pytensor.graph import ancestors
+
+from pymc import Model
+from pymc.model.fgraph import (
+    ModelObservedRV,
+    ModelVar,
+    fgraph_from_model,
+    model_from_fgraph,
+)
+
+ModelVariable = Union[Variable, str]
+
+
+def prune_vars_detached_from_observed(model: Model) -> Model:
+    """Prune model variables that are not related to any observed variable in the Model."""
+
+    # Potentials are ambiguous as to whether they correspond to likelihood or prior terms;
+    # we simply raise for now
+    if model.potentials:
+        raise NotImplementedError("Pruning not implemented for models with Potentials")
+
+    fgraph, _ = fgraph_from_model(model, inlined_views=True)
+    observed_vars = (
+        out
+        for node in fgraph.apply_nodes
+        if isinstance(node.op, ModelObservedRV)
+        for out in node.outputs
+    )
+    ancestor_nodes = {var.owner for var in ancestors(observed_vars)}
+    nodes_to_remove = {
+        node
+        for node in fgraph.apply_nodes
+        if isinstance(node.op, ModelVar) and node not in ancestor_nodes
+    }
+    for node_to_remove in nodes_to_remove:
+        fgraph.remove_node(node_to_remove)
+    return model_from_fgraph(fgraph)
+
+
+def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]:
+    if not isinstance(vars, (list, tuple)):
+        vars = (vars,)
+    return [model[var] if isinstance(var, str) else var for var in vars]
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
new file mode 100644
index 00000000000..3e81c4fb9c5
--- /dev/null
+++ b/pymc/model/transform/conditioning.py
@@ -0,0 +1,402 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+
+from typing import Any, List, Mapping, Optional, Sequence, Union
+
+from pytensor import Variable
+from pytensor.graph import ancestors
+from pytensor.graph.basic import walk
+from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.op import RandomVariable
+
+from pymc import Model
+from pymc.logprob.transforms import RVTransform
+from pymc.model.fgraph import (
+    ModelDeterministic,
+    ModelFreeRV,
+    extract_dims,
+    fgraph_from_model,
+    model_deterministic,
+    model_free_rv,
+    model_from_fgraph,
+    model_named,
+    model_observed_rv,
+)
+from pymc.model.transform.basic import (
+    ModelVariable,
+    parse_vars,
+    prune_vars_detached_from_observed,
+)
+from pymc.pytensorf import _replace_vars_in_graphs, toposort_replace
+from pymc.util import get_transformed_name, get_untransformed_name
+
+
+def observe(
+    model: Model, vars_to_observations: Mapping[Union[str, TensorVariable], Any]
+) -> Model:
+    """Convert free RVs or Deterministics to observed RVs.
+
+    Parameters
+    ----------
+    model: PyMC Model
+    vars_to_observations: Dict of variable or name to TensorLike
+        Dictionary that maps model variables (or names) to observed values.
+        Observed values must have a shape and data type that are compatible
+        with the original model variable.
+
+    Returns
+    -------
+    new_model: PyMC model
+        A distinct PyMC model with the relevant variables observed.
+        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc.model.transform.conditioning import observe
+
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal("y", x)
+            z = pm.Normal("z", y)
+
+        m_new = observe(m, {y: 0.5})
+
+    Deterministic variables can also be observed.
+    This relies on PyMC's ability to infer the logp of the underlying expression.
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc.model.transform.conditioning import observe
+
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal.dist(x, shape=(5,))
+            y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
+
+        new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})
+
+    """
+    vars_to_observations = {
+        model[var] if isinstance(var, str) else var: obs
+        for var, obs in vars_to_observations.items()
+    }
+
+    valid_model_vars = set(model.free_RVs + model.deterministics)
+    if any(var not in valid_model_vars for var in vars_to_observations):
+        raise ValueError("At least one var is not a free variable or deterministic in the model")
+
+    fgraph, memo = fgraph_from_model(model)
+
+    replacements = {}
+    for var, obs in vars_to_observations.items():
+        model_var = memo[var]
+
+        # Just a sanity check
+        assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
+        assert model_var in fgraph.variables
+
+        var = model_var.owner.inputs[0]
+        var.name = model_var.name
+        dims = extract_dims(model_var)
+        model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims)
+        replacements[model_var] = model_obs_rv
+
+    toposort_replace(fgraph, tuple(replacements.items()))
+
+    return model_from_fgraph(fgraph)
+
+
+def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
+    def replacement_fn(var, inner_replacements):
+        if var in replacements:
+            inner_replacements[var] = replacements[var]
+
+        # Handle root inputs as those will never be passed to the replacement_fn
+        for inp in var.owner.inputs:
+            if inp.owner is None and inp in replacements:
+                inner_replacements[inp] = replacements[inp]
+
+        return [var]
+
+    replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
+    return replaced_graphs
+
+
+def rvs_in_graph(vars: Sequence[Variable]) -> bool:
+    """Check if there are any RVs in the graphs of vars."""
+
+    from pymc.distributions.distribution import SymbolicRandomVariable
+
+    def expand(r):
+        owner = r.owner
+        if owner:
+            inputs = list(reversed(owner.inputs))
+
+            if isinstance(owner.op, HasInnerGraph):
+                inputs += owner.op.inner_outputs
+
+            return inputs
+
+    return any(
+        node
+        for node in walk(vars, expand, False)
+        if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
+    )
+
+
+def do(
+    model: Model,
+    vars_to_interventions: Mapping[Union[str, TensorVariable], Any],
+    prune_vars=False,
+) -> Model:
+    """Replace model variables by intervention variables.
+
+    Intervention variables will either show up as `Data` or `Deterministics` in the new model,
+    depending on whether they depend on other RandomVariables or not.
+
+    Parameters
+    ----------
+    model: PyMC Model
+    vars_to_interventions: Dict of variable or name to TensorLike
+        Dictionary that maps model variables (or names) to intervention expressions.
+        Intervention expressions must have a shape and data type that are compatible
+        with the original model variable.
+    prune_vars: bool, defaults to False
+        Whether to prune model variables that are not connected to any observed variables,
+        after the interventions.
+
+    Returns
+    -------
+    new_model: PyMC model
+        A distinct PyMC model with the relevant variables replaced by the intervention expressions.
+        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import arviz as az
+        import pymc as pm
+        from pymc.model.transform.conditioning import do
+
+        with pm.Model() as m:
+            x = pm.Normal("x", 0, 1)
+            y = pm.Normal("y", x, 1)
+            z = pm.Normal("z", y + x, 1)
+
+        # Dummy posterior, same as calling `pm.sample`
+        idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})
+
+        # Replace `y` by a constant `100.0`
+        m_do = do(m, {y: 100.0})
+        with m_do:
+            idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
+
+    """
+    do_mapping = {}
+    for var, obs in vars_to_interventions.items():
+        if isinstance(var, str):
+            var = model[var]
+        try:
+            do_mapping[var] = var.type.filter_variable(obs)
+        except TypeError as err:
+            raise TypeError(
+                "Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"
+            ) from err
+
+    if any(var not in model.named_vars.values() for var in do_mapping):
+        raise ValueError("At least one var is not a named variable in the model")
+
+    fgraph, memo = fgraph_from_model(model, inlined_views=True)
+
+    # We need the interventions defined in terms of the IR fgraph representation,
+    # in case they reference other variables in the model
+    ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)
+
+    replacements = {}
+    for var, intervention in zip(do_mapping, ir_interventions):
+        model_var = memo[var]
+
+        # Just a sanity check
+        assert model_var in fgraph.variables
+
+        # If the intervention references the original variable we must give it a different name
+        if model_var in ancestors([intervention]):
+            intervention.name = f"do_{model_var.name}"
+            warnings.warn(
+                f"Intervention expression references the variable that is being intervened: {model_var.name}. "
+                f"Intervention will be given the name: {intervention.name}"
+            )
+        else:
+            intervention.name = model_var.name
+        dims = extract_dims(model_var)
+        # If there are any RVs in the graph we introduce the intervention as a deterministic
+        if rvs_in_graph([intervention]):
+            new_var = model_deterministic(intervention.copy(name=intervention.name), *dims)
+        # Otherwise as a named variable (Constant or Shared data)
+        else:
+            new_var = model_named(intervention, *dims)
+
+        replacements[model_var] = new_var
+
+    # Replace variables by interventions
+    toposort_replace(fgraph, tuple(replacements.items()))
+
+    model = model_from_fgraph(fgraph)
+    if prune_vars:
+        return prune_vars_detached_from_observed(model)
+    return model
+
+
+def change_value_transforms(
+    model: Model,
+    vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]],
+) -> Model:
+    """Change the transforms of the value variables in the model.
+
+    Parameters
+    ----------
+    model : Model
+    vars_to_transforms : Dict
+        Dictionary that maps RVs to new transforms to be applied to the respective value variables
+
+    Returns
+    -------
+    new_model : Model
+        Model with the updated transformed value variables
+
+    Examples
+    --------
+    Extract untransformed space Hessian after finding transformed space MAP
+
+    .. code-block:: python

+        import pymc as pm
+        from pymc.distributions.transforms import logodds
+        from pymc.model.transform.conditioning import change_value_transforms
+
+        with pm.Model() as base_m:
+            p = pm.Uniform("p", 0, 1, transform=None)
+            w = pm.Binomial("w", n=9, p=p, observed=6)
+
+        with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
+            mean_q = pm.find_MAP()
+
+        with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
+            new_p = untransformed_p['p']
+            std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
+
+        print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
+        # Mean, Standard deviation
+        # p 0.67, 0.16
+
+    """
+    vars_to_transforms = {
+        parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items()
+    }
+
+    if set(vars_to_transforms.keys()) - set(model.free_RVs):
+        raise ValueError(f"All keys must be free variables in the model: {model.free_RVs}")
+
+    fgraph, memo = fgraph_from_model(model)
+
+    vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()}
+    replacements = {}
+    for node in fgraph.apply_nodes:
+        if not isinstance(node.op, ModelFreeRV):
+            continue
+
+        [dummy_rv] = node.outputs
+        if dummy_rv not in vars_to_transforms:
+            continue
+
+        transform = vars_to_transforms[dummy_rv]
+
+        rv, value, *dims = node.inputs
+
+        new_value = rv.type()
+        try:
+            untransformed_name = get_untransformed_name(value.name)
+        except ValueError:
+            untransformed_name = value.name
+        if transform:
+            new_name = get_transformed_name(untransformed_name, transform)
+        else:
+            new_name = untransformed_name
+        new_value.name = new_name
+
+        new_dummy_rv = model_free_rv(rv, new_value, transform, *dims)
+        replacements[dummy_rv] = new_dummy_rv
+
+    toposort_replace(fgraph, tuple(replacements.items()))
+    return model_from_fgraph(fgraph)
+
+
+def remove_value_transforms(
+    model: Model,
+    vars: Optional[Sequence[ModelVariable]] = None,
+) -> Model:
+    """Remove the transforms of the value variables in the model.
+
+    Parameters
+    ----------
+    model : Model
+    vars : Model variables, optional
+        Model variables for which to remove transforms. Defaults to all transformed variables.
+
+    Returns
+    -------
+    new_model : Model
+        Model with the removed transformed value variables
+
+    Examples
+    --------
+    Extract untransformed space Hessian after finding transformed space MAP
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc.model.transform.conditioning import remove_value_transforms
+
+        with pm.Model() as transformed_m:
+            p = pm.Uniform("p", 0, 1)
+            w = pm.Binomial("w", n=9, p=p, observed=6)
+            mean_q = pm.find_MAP()
+
+        with remove_value_transforms(transformed_m) as untransformed_m:
+            new_p = untransformed_m["p"]
+            std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
+            print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
+
+        # Mean, Standard deviation
+        # p 0.67, 0.16
+
+    """
+    if vars is None:
+        vars = model.free_RVs
+    return change_value_transforms(model, {var: None for var in vars})
+
+
+__all__ = (
+    "change_value_transforms",
+    "do",
+    "observe",
+    "remove_value_transforms",
+)
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 066ba233c5b..f79f5461f99 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -35,7 +35,7 @@
 from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.gradient import grad
-from pytensor.graph import node_rewriter, rewrite_graph
+from pytensor.graph import Type, node_rewriter, rewrite_graph
 from pytensor.graph.basic import (
     Apply,
     Constant,
@@ -1234,3 +1234,46 @@ def rewrite_pregrad(graph):
     pre-grad.
     """
     return rewrite_graph(graph, include=("canonicalize", "stabilize"))
+
+
+class StringType(Type[str]):
+    def clone(self, **kwargs):
+        return type(self)()
+
+    def filter(self, x, strict=False, allow_downcast=None):
+        if isinstance(x, str):
+            return x
+        else:
+            raise TypeError("Expected a string!")
+
+    def __str__(self):
+        return "string"
+
+    @staticmethod
+    def may_share_memory(a, b):
+        return isinstance(a, str) and a is b
+
+
+stringtype = StringType()
+
+
+class StringConstant(Constant):
+    pass
+
+
+@pytensor._as_symbolic.register(str)
+def as_symbolic_string(x, **kwargs):
+    return StringConstant(stringtype, x)
+
+
+def toposort_replace(
+    fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
+) -> None:
+    """Replace multiple variables in topological order."""
+    toposort = fgraph.toposort()
+    sorted_replacements = sorted(
+        replacements,
+        key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
+        reverse=reverse,
+    )
+    fgraph.replace_all(sorted_replacements, import_missing=True)
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 0dc73ef2a08..77da8e8a0fe 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -39,7 +39,10 @@
 pymc/logprob/tensor.py
 pymc/logprob/transforms.py
 pymc/logprob/utils.py
-pymc/model.py
+pymc/model/core.py
+pymc/model/fgraph.py
+pymc/model/transform/basic.py
+pymc/model/transform/conditioning.py
 pymc/model_graph.py
 pymc/printing.py
 pymc/pytensorf.py
diff --git a/tests/model/__init__.py b/tests/model/__init__.py
new file mode 100644
index 00000000000..fcb37fe8c79
--- /dev/null
+++ b/tests/model/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_model.py b/tests/model/test_core.py similarity index 99% rename from tests/test_model.py rename to tests/model/test_core.py index 490047d19fa..33a948c109c 100644 --- a/tests/test_model.py +++ b/tests/model/test_core.py @@ -46,10 +46,9 @@ from pymc.distributions.distribution import PartialObservedRV from pymc.distributions.transforms import log, simplex from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning -from pymc.logprob.basic import conditional_logp, transformed_conditional_logp +from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Point, ValueGradFunction, modelcontext -from pymc.model_graph import model_to_graphviz from pymc.util import _FutureWarningValidatingScratchpad from pymc.variational.minibatch_rv import MinibatchRandomVariable from tests.models import simple_model @@ -1676,7 +1675,7 @@ def school_model(J: int) -> pm.Model: ) def test_graphviz_call_function(self, var_names) -> None: model = self.school_model(J=8) - with patch("pymc.model.model_to_graphviz") as mock_model_to_graphviz: + with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz: model.to_graphviz(var_names=var_names) mock_model_to_graphviz.assert_called_once_with( model=model, var_names=var_names, formatting="plain" diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py new file mode 100644 index 00000000000..eaa73afc182 --- /dev/null +++ b/tests/model/test_fgraph.py @@ -0,0 +1,353 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytensor.tensor as pt +import pytest + +from pytensor import config, shared +from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor.exceptions import NotScalarConstantError + +import pymc as pm + +from pymc.model.fgraph import ( + ModelDeterministic, + ModelFreeRV, + ModelNamed, + ModelObservedRV, + ModelPotential, + ModelVar, + clone_model, + fgraph_from_model, + model_deterministic, + model_free_rv, + model_from_fgraph, +) + + +def test_basic(): + """Test we can convert from a PyMC Model to a FunctionGraph and back""" + with pm.Model(coords={"test_dim": range(3)}) as m_old: + x = pm.Normal("x") + y = pm.Deterministic("y", x + 1) + w = pm.HalfNormal("w", pm.math.exp(y)) + z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",)) + pot = pm.Potential("pot", x * 2) + + m_fgraph, memo = fgraph_from_model(m_old) + assert isinstance(m_fgraph, FunctionGraph) + + assert isinstance(memo[x].owner.op, ModelFreeRV) + assert isinstance(memo[y].owner.op, ModelDeterministic) + assert isinstance(memo[w].owner.op, ModelFreeRV) + assert isinstance(memo[z].owner.op, ModelObservedRV) + assert isinstance(memo[pot].owner.op, ModelPotential) + + m_new = model_from_fgraph(m_fgraph) + assert isinstance(m_new, pm.Model) + + assert m_new.coords == {"test_dim": tuple(range(3))} + assert m_new._dim_lengths["test_dim"].eval() == 3 + assert m_new.named_vars_to_dims == {"z": ["test_dim"]} + + named_vars = {"x", "y", "w", "z", "pot"} + assert set(m_new.named_vars) == named_vars + for named_var in named_vars: + assert m_new[named_var] is not m_old[named_var] + for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()): + # Constants are not cloned + if not isinstance(value_new, Constant): + assert value_new is not value_old + assert m_new["x"] in m_new.free_RVs + assert m_new["w"] in m_new.free_RVs + assert m_new["y"] in m_new.deterministics + assert m_new["z"] in m_new.observed_RVs + assert m_new["pot"] in m_new.potentials + assert m_new.rvs_to_transforms[m_new["x"]] is None + assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log + assert m_new.rvs_to_transforms[m_new["z"]] is None + + # Test random + new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1) + old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1) + np.testing.assert_array_equal(new_y_draw, old_y_draw) + np.testing.assert_array_equal(new_z_draw, old_z_draw) + + # Test logp + ip = m_new.initial_point() + np.testing.assert_equal( + m_new.compile_logp()(ip), + m_old.compile_logp()(ip), + ) + + +def same_storage(shared_1, shared_2) -> bool: + """Check if two shared variables have the same storage containers (i.e., they point to the same memory).""" + return shared_1.container.storage is shared_2.container.storage + + +@pytest.mark.parametrize("inline_views", (False, True)) +def test_data(inline_views): + """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly. + + All model-related shared variables should be copied to become independent across models. 
+    """
+    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
+        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
+        y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
+        b0 = pm.ConstantData("b0", np.zeros((1,)))
+        b1 = pm.DiracDelta("b1", 1.0)
+        mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
+        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
+
+    m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
+    assert isinstance(memo[x].owner.op, ModelNamed)
+    assert isinstance(memo[y].owner.op, ModelNamed)
+    assert isinstance(memo[b0].owner.op, ModelNamed)
+    mu_inp = memo[mu].owner.inputs[0]
+    obs = memo[obs]
+    if not inline_views:
+        # Add(b0, Mul(FreeRV(b1), x)) not Add(Named(b0), Mul(FreeRV(b1), Named(x)))
+        assert mu_inp.owner.inputs[0] is memo[b0].owner.inputs[0]
+        assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
+        # ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
+        assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
+    else:
+        assert mu_inp.owner.inputs[0] is memo[b0]
+        assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
+        assert obs.owner.inputs[1] is memo[y]
+
+    m_new = model_from_fgraph(m_fgraph)
+
+    # The rv-data mapping is preserved
+    assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]
+
+    # ConstantData is still accessible as a model variable
+    np.testing.assert_array_equal(m_new["b0"], m_old["b0"])
+
+    # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
+    assert not same_storage(m_new["x"], x)
+    assert not same_storage(m_new["y"], y)
+    assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
+    assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])
+
+    # Updating shared variables in the new model does not affect the old one
+    with m_new:
+        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
+    assert m_new.dim_lengths["test_dim"].eval() == 2
+    assert m_old.dim_lengths["test_dim"].eval() == 3
+    np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0])
+    np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6)
+
+
+@config.change_flags(floatX="float64")  # Avoid downcasting Ops in the graph
+def test_shared_variable():
+    """Test that user-defined shared variables (other than RNGs) aren't copied."""
+    x = shared(np.array([1, 2, 3.0]), name="x")
+    y = shared(np.array([1, 2, 3.0]), name="y")
+
+    with pm.Model() as m_old:
+        test = pm.Normal("test", mu=x, observed=y)
+
+    assert test.owner.inputs[3] is x
+    assert m_old.rvs_to_values[test] is y
+
+    m_new = clone_model(m_old)
+    test_new = m_new["test"]
+    # Shared variables are cloned but still point to the same memory
+    assert test_new.owner.inputs[3] is not x
+    assert m_new.rvs_to_values[test_new] is not y
+    assert same_storage(test_new.owner.inputs[3], x)
+    assert same_storage(m_new.rvs_to_values[test_new], y)
+
+
+@pytest.mark.parametrize("inline_views", (False, True))
+def test_deterministics(inline_views):
+    """Test handling of deterministics.
+
+    We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome.
+    However, we want them in the middle of Model.basic_RVs, so they display nicely in graphviz.
+
+    There is one edge case that has to be considered, when a Deterministic is just a copy of an RV.
+    In that case we don't bother to reintroduce it in between other Model.basic_RVs.
+    """
+    with pm.Model() as m:
+        x = pm.Normal("x")
+        mu = pm.Deterministic("mu", pm.math.abs(x))
+        sigma = pm.math.exp(x)
+        pm.Deterministic("sigma", sigma)
+        y = pm.Normal("y", mu, sigma)
+        # Special case where the Deterministic
+        # is a direct view on another model variable
+        y_ = pm.Deterministic("y_", y)
+        # Just for kicks, make it a double one!
+        y__ = pm.Deterministic("y__", y_)
+        z = pm.Normal("z", y__)
+
+    # Deterministic mu is in the graph of x to y but not sigma
+    assert m["y"].owner.inputs[3] is m["mu"]
+    assert m["y"].owner.inputs[4] is not m["sigma"]
+
+    fg, _ = fgraph_from_model(m, inlined_views=inline_views)
+
+    # Check that no Deterministics are in graph of x to y and y to z
+    x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
+    # [Det(mu), Det(sigma)]
+    mu = det_mu.owner.inputs[0]
+    sigma = det_sigma.owner.inputs[0]
+    assert y.owner.inputs[0].owner.inputs[4] is sigma
+    assert det_y_ is not det_y__
+    assert det_y_.owner.inputs[0] is y
+    if not inline_views:
+        # FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma)))
+        assert y.owner.inputs[0].owner.inputs[3] is mu
+        # FreeRV(z(y)) not FreeRV(z(Det(Det(y))))
+        assert z.owner.inputs[0].owner.inputs[3] is y
+        # Det(y), not Det(Det(y))
+        assert det_y__.owner.inputs[0] is y
+    else:
+        assert y.owner.inputs[0].owner.inputs[3] is det_mu
+        assert z.owner.inputs[0].owner.inputs[3] is det_y__
+        assert det_y__.owner.inputs[0] is det_y_
+
+    # Both mu and sigma deterministics are now in the graph of x to y
+    m = model_from_fgraph(fg)
+    assert m["y"].owner.inputs[3] is m["mu"]
+    assert m["y"].owner.inputs[4] is m["sigma"]
+    # But not y_* in y to z, since there was no real Op in between
+    assert m["z"].owner.inputs[3] is m["y"]
+    assert m["y_"].owner.inputs[0] is m["y"]
+    assert m["y__"].owner.inputs[0] is m["y"]
+
+
+def test_context_error():
+    """Test that model_from_fgraph fails when called inside a Model context.
+
+    We can't allow it, because the new Model that's returned would be a child of whatever Model context is active.
+ """ + with pm.Model() as m: + x = pm.Normal("x") + + fg = fgraph_from_model(m) + + with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"): + model_from_fgraph(fg) + + +def test_sub_model_error(): + """Test Error is raised when trying to convert a sub-model to fgraph.""" + with pm.Model() as m: + x = pm.Beta("x", 1, 1) + with pm.Model() as sub_m: + y = pm.Normal("y", x) + + nodes = [v for v in fgraph_from_model(m)[0].toposort() if not isinstance(v.op, ModelVar)] + assert len(nodes) == 2 + assert isinstance(nodes[0].op, pm.Beta) + assert isinstance(nodes[1].op, pm.Normal) + + with pytest.raises(ValueError, match="Nested sub-models cannot be converted"): + fgraph_from_model(sub_m) + + +@pytest.fixture() +def non_centered_rewrite(): + @node_rewriter(tracks=[ModelFreeRV]) + def non_centered_param(fgraph: FunctionGraph, node): + """Rewrite that replaces centered normal by non-centered parametrization.""" + + rv, value, *dims = node.inputs + if not isinstance(rv.owner.op, pm.Normal): + return + rng, size, dtype, loc, scale = rv.owner.inputs + + # Only apply rewrite if size information is explicit + if size.ndim == 0: + return None + + try: + is_unit = ( + pt.get_underlying_scalar_constant_value(loc) == 0 + and pt.get_underlying_scalar_constant_value(scale) == 1 + ) + except NotScalarConstantError: + is_unit = False + + # Nothing to do here + if is_unit: + return + + raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng) + raw_norm.name = f"{rv.name}_raw_" + raw_norm_value = raw_norm.clone() + fgraph.add_input(raw_norm_value) + raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims) + + new_norm = loc + raw_norm * scale + new_norm.name = rv.name + new_norm_det = model_deterministic(new_norm, *dims) + fgraph.add_output(new_norm_det) + + return [new_norm] + + return in2out(non_centered_param) + + +def test_fgraph_rewrite(non_centered_rewrite): + """Test we can apply a simple rewrite to a PyMC Model.""" + + with pm.Model(coords={"subject": range(10)}) as m_old: + group_mean = pm.Normal("group_mean") + group_std = pm.HalfNormal("group_std") + subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",)) + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",)) + + fg, _ = fgraph_from_model(m_old) + non_centered_rewrite.apply(fg) + + m_new = model_from_fgraph(fg) + assert m_new.named_vars_to_dims == { + "subject_mean": ["subject"], + "subject_mean_raw_": ["subject"], + "obs": ["subject"], + } + assert set(m_new.named_vars) == { + "group_mean", + "group_std", + "subject_mean_raw_", + "subject_mean", + "obs", + } + assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"} + assert {rv.name for rv in m_new.observed_RVs} == {"obs"} + assert {rv.name for rv in m_new.deterministics} == {"subject_mean"} + + with pm.Model() as m_ref: + group_mean = pm.Normal("group_mean") + group_std = pm.HalfNormal("group_std") + subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,)) + subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std) + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10)) + + np.testing.assert_array_equal( + pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1), + pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1), + ) + + ip = m_new.initial_point() + np.testing.assert_equal( + m_new.compile_logp()(ip), + m_ref.compile_logp()(ip), + ) diff --git a/tests/model/transform/__init__.py 
b/tests/model/transform/__init__.py new file mode 100644 index 00000000000..fcb37fe8c79 --- /dev/null +++ b/tests/model/transform/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py new file mode 100644 index 00000000000..484c16740ae --- /dev/null +++ b/tests/model/transform/test_basic.py @@ -0,0 +1,32 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pymc as pm + +from pymc.model.transform.basic import prune_vars_detached_from_observed + + +def test_prune_vars_detached_from_observed(): + with pm.Model() as m: + obs_data = pm.MutableData("obs_data", 0) + a0 = pm.ConstantData("a0", 0) + a1 = pm.Normal("a1", a0) + a2 = pm.Normal("a2", a1) + pm.Normal("obs", a2, observed=obs_data) + + d0 = pm.ConstantData("d0", 0) + d1 = pm.Normal("d1", d0) + + assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} + pruned_m = prune_vars_detached_from_observed(m) + assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} diff --git a/tests/model/transform/test_conditioning.py b/tests/model/transform/test_conditioning.py new file mode 100644 index 00000000000..ce486607a63 --- /dev/null +++ b/tests/model/transform/test_conditioning.py @@ -0,0 +1,310 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import arviz as az +import numpy as np +import pytest + +from pytensor import config + +import pymc as pm + +from pymc.distributions.transforms import logodds +from pymc.model.transform.conditioning import ( + change_value_transforms, + do, + observe, + remove_value_transforms, +) +from pymc.variational.minibatch_rv import create_minibatch_rv + + +def test_observe(): + with pm.Model() as m_old: + x = pm.Normal("x") + y = pm.Normal("y", x) + z = pm.Normal("z", y) + + m_new = observe(m_old, {y: 0.5}) + + assert len(m_new.free_RVs) == 2 + assert len(m_new.observed_RVs) == 1 + assert m_new["x"] in m_new.free_RVs + assert m_new["y"] in m_new.observed_RVs + assert m_new["z"] in m_new.free_RVs + + np.testing.assert_allclose( + m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}), + m_new.compile_logp()({"x": 0.9, "z": 1.4}), + ) + + # Test two substitutions + m_new = observe(m_old, {y: 0.5, z: 1.4}) + + assert len(m_new.free_RVs) == 1 + assert len(m_new.observed_RVs) == 2 + assert m_new["x"] in m_new.free_RVs + assert m_new["y"] in m_new.observed_RVs + assert m_new["z"] in m_new.observed_RVs + + np.testing.assert_allclose( + m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}), + m_new.compile_logp()({"x": 0.9}), + ) + + +def test_observe_minibatch(): + data = np.zeros((100,), dtype=config.floatX) + batch_size = 10 + with pm.Model() as m_old: + x = pm.Normal("x") + y = pm.Normal("y", x) + # Minibatch RVs are usually created with `total_size` kwarg + z_raw = pm.Normal.dist(y, shape=batch_size) + mb_z = create_minibatch_rv(z_raw, total_size=data.shape) + m_old.register_rv(mb_z, name="mb_z") + + mb_data = pm.Minibatch(data, batch_size=batch_size) + m_new = observe(m_old, {mb_z: mb_data}) + + assert len(m_new.free_RVs) == 2 + assert len(m_new.observed_RVs) == 1 + assert m_new["x"] in m_new.free_RVs + assert m_new["y"] in m_new.free_RVs + assert m_new["mb_z"] in m_new.observed_RVs + + np.testing.assert_allclose( + m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}), + m_new.compile_logp()({"x": 0.9, "y": 0.5}), + ) + + +def test_observe_deterministic(): + y_censored_obs = np.array([0.9, 0.5, 0.3, 1, 1], dtype=config.floatX) + + with pm.Model() as m_old: + x = pm.Normal("x") + y = pm.Normal.dist(x, shape=(5,)) + y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1)) + + m_new = observe(m_old, {y_censored: y_censored_obs}) + + with pm.Model() as m_ref: + x = pm.Normal("x") + pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs) + + +def test_observe_dims(): + with pm.Model(coords={"test_dim": range(5)}) as m_old: + x = pm.Normal("x", dims="test_dim") + + m_new = observe(m_old, {x: np.arange(5, dtype=config.floatX)}) + assert m_new.named_vars_to_dims["x"] == ["test_dim"] + + +def test_do(): + rng = np.random.default_rng(seed=435) + with pm.Model() as m_old: + x = pm.Normal("x", 0, 1e-3) + y = pm.Normal("y", x, 1e-3) + z = pm.Normal("z", y + x, 1e-3) + + assert -5 < pm.draw(z, random_seed=rng) < 5 + + m_new = do(m_old, {y: x + 100}) + + assert len(m_new.free_RVs) == 2 + assert m_new["x"] in m_new.free_RVs + assert m_new["y"] in m_new.deterministics + assert m_new["z"] in m_new.free_RVs + + assert 95 < pm.draw(m_new["z"], random_seed=rng) < 105 + + # Test two substitutions + with m_old: + switch = pm.MutableData("switch", 1) + m_new = do(m_old, {y: 100 * switch, x: 100 * switch}) + + assert len(m_new.free_RVs) == 1 + assert m_new["y"] not in m_new.deterministics + assert m_new["x"] not in m_new.deterministics + assert m_new["z"] in 
+
+    assert 195 < pm.draw(m_new["z"], random_seed=rng) < 205
+    with m_new:
+        pm.set_data({"switch": 0})
+    assert -5 < pm.draw(m_new["z"], random_seed=rng) < 5
+
+
+def test_do_posterior_predictive():
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1)
+        y = pm.Normal("y", x, 1)
+        z = pm.Normal("z", y + x, 1e-3)
+
+    # Dummy posterior
+    idata_m = az.from_dict(
+        {
+            "x": np.full((2, 500), 25),
+            "y": np.full((2, 500), np.nan),
+            "z": np.full((2, 500), np.nan),
+        }
+    )
+
+    # Replace `y` by a constant `100.0`
+    m_do = do(m, {y: 100.0})
+    with m_do:
+        idata_do = pm.sample_posterior_predictive(idata_m, var_names=["z"])
+
+    assert 120 < idata_do.posterior_predictive["z"].mean() < 130
+
+
+@pytest.mark.parametrize("mutable", (False, True))
+def test_do_constant(mutable):
+    rng = np.random.default_rng(seed=122)
+    with pm.Model() as m:
+        x = pm.Data("x", 0, mutable=mutable)
+        y = pm.Normal("y", x, 1e-3)
+
+    do_m = do(m, {x: 105})
+    assert pm.draw(do_m["y"], random_seed=rng) > 100
+
+
+def test_do_deterministic():
+    rng = np.random.default_rng(seed=435)
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1e-3)
+        y = pm.Deterministic("y", x + 105)
+        z = pm.Normal("z", y, 1e-3)
+
+    do_m = do(m, {"z": x - 105})
+    assert pm.draw(do_m["z"], random_seed=rng) < 100
+
+
+def test_do_dims():
+    coords = {"test_dim": range(10)}
+    with pm.Model(coords=coords) as m:
+        x = pm.Normal("x", dims="test_dim")
+        y = pm.Deterministic("y", x + 5, dims="test_dim")
+
+    do_m = do(
+        m,
+        {"x": np.zeros(10, dtype=config.floatX)},
+    )
+    assert do_m.named_vars_to_dims["x"] == ["test_dim"]
+
+    do_m = do(
+        m,
+        {"y": np.zeros(10, dtype=config.floatX)},
+    )
+    assert do_m.named_vars_to_dims["y"] == ["test_dim"]
+
+
+@pytest.mark.parametrize("prune", (False, True))
+def test_do_prune(prune):
+    with pm.Model() as m:
+        x0 = pm.ConstantData("x0", 0)
+        x1 = pm.ConstantData("x1", 0)
+        y = pm.Normal("y")
+        y_det = pm.Deterministic("y_det", y + x0)
+        z = pm.Normal("z", y_det)
+        llike = pm.Normal("llike", z + x1, observed=0)
+
+    orig_named_vars = {"x0", "x1", "y", "y_det", "z", "llike"}
+    assert set(m.named_vars) == orig_named_vars
+
+    do_m = do(m, {y_det: x0 + 5}, prune_vars=prune)
+    if prune:
+        assert set(do_m.named_vars) == {"x0", "x1", "y_det", "z", "llike"}
+    else:
+        assert set(do_m.named_vars) == orig_named_vars
+
+    do_m = do(m, {z: 0.5}, prune_vars=prune)
+    if prune:
+        assert set(do_m.named_vars) == {"x1", "z", "llike"}
+    else:
+        assert set(do_m.named_vars) == orig_named_vars
+
+
+def test_do_self_reference():
+    """Check we can replace a variable by an expression that refers to the same variable."""
+    with pm.Model() as m:
+        x = pm.Normal("x", 0, 1)
+
+    with pytest.warns(
+        UserWarning,
+        match="Intervention expression references the variable that is being intervened",
+    ):
+        new_m = do(m, {x: x + 100})
+
+    x = new_m["x"]
+    do_x = new_m["do_x"]
+    draw_x, draw_do_x = pm.draw([x, do_x], draws=5)
+    np.testing.assert_allclose(draw_x + 100, draw_do_x)
+
+
+def test_change_value_transforms():
+    with pm.Model() as base_m:
+        p = pm.Uniform("p", 0, 1, transform=None)
+        w = pm.Binomial("w", n=9, p=p, observed=6)
+
+    assert base_m.rvs_to_transforms[p] is None
+    assert base_m.rvs_to_values[p].name == "p"
+
+    with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
+        new_p = transformed_p["p"]
+        assert transformed_p.rvs_to_transforms[new_p] == logodds
+        assert transformed_p.rvs_to_values[new_p].name == "p_logodds__"
+        mean_q = pm.find_MAP(progressbar=False)
+
+    with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
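+        # Round-trip back to an untransformed value variable: find_hessian
+        # below is then evaluated on p directly, on the constrained scale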
+        new_p = untransformed_p["p"]
+        assert untransformed_p.rvs_to_transforms[new_p] is None
+        assert untransformed_p.rvs_to_values[new_p].name == "p"
+        std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
+
+    np.testing.assert_allclose(np.round(mean_q["p"], 2), 0.67)
+    np.testing.assert_allclose(np.round(std_q[0], 2), 0.16)
+
+
+def test_change_value_transforms_error():
+    with pm.Model() as m:
+        x = pm.Uniform("x", observed=5.0)
+
+    with pytest.raises(ValueError, match="All keys must be free variables in the model"):
+        change_value_transforms(m, {x: logodds})
+
+
+def test_remove_value_transforms():
+    with pm.Model() as base_m:
+        p = pm.Uniform("p", transform=logodds)
+        q = pm.Uniform("q", transform=logodds)
+
+    new_m = remove_value_transforms(base_m)
+    new_p = new_m["p"]
+    new_q = new_m["q"]
+    assert new_m.rvs_to_transforms == {new_p: None, new_q: None}
+
+    new_m = remove_value_transforms(base_m, [p, q])
+    new_p = new_m["p"]
+    new_q = new_m["q"]
+    assert new_m.rvs_to_transforms == {new_p: None, new_q: None}
+
+    new_m = remove_value_transforms(base_m, [p])
+    new_p = new_m["p"]
+    new_q = new_m["q"]
+    assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds}
+
+    new_m = remove_value_transforms(base_m, ["q"])
+    new_p = new_m["p"]
+    new_q = new_m["q"]
+    assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}
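
For reviewers, a minimal usage sketch distilled from the tests above (illustration only, not part of the patch; the model and the concrete values are made up):

    import pymc as pm
    from pymc.model.transform.conditioning import do, observe

    with pm.Model() as m:
        x = pm.Normal("x")
        y = pm.Normal("y", x)

    # Condition y on a concrete observation, as test_observe does
    m_obs = observe(m, {y: 0.5})

    # Or intervene on y, replacing it with a deterministic function of x
    m_do = do(m, {y: x + 100})

Both transforms return a new model and leave the original untouched, which is what lets the tests compare `m_old` and `m_new` side by side.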