Skip to content

Commit

Permalink
Support for PyMC >= 5 (#38)
Browse files Browse the repository at this point in the history
* Update isort version for pre-commit

PyCQA/isort#2108

* Require pymc >= 5.0.0 and replace aesara by pytensor

* Update testval to initval in utils_test.py

testval is deprecated now

* Add tests covering bounded (uniform) variables and lists of variables for utils.eval_in_model()

Bounded variables are transformed by PyMC, which has implications for rvs_to_values I think
Lists of variables are handled differently than single variables by utils.Evaluator

* Replace `pymc.pytensorf.rvs_to_value_vars()` by `pymc.Model().replace_rvs_by_values()`

These should be equivalent. `pymc.pytensorf.rvs_to_value_vars()` requires extra argument, but the model method fills them automatically.
  • Loading branch information
vandalt authored Nov 1, 2023
1 parent efec0cf commit 6ea4f8e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
exclude: docs/tutorials

- repo: https://github.com/PyCQA/isort
rev: "5.10.1"
rev: "5.11.5"
hooks:
- id: isort
args: []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"Programming Language :: Python",
"Programming Language :: Python :: 3",
]
INSTALL_REQUIRES = ["pymc"]
INSTALL_REQUIRES = ["pymc >= 5.0.0"]
EXTRA_REQUIRE = {
"test": ["pytest"],
"notebooks": [
Expand Down
12 changes: 6 additions & 6 deletions src/pymc_ext/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import warnings

import aesara.tensor as at
import numpy as np
import pymc as pm
import pytensor.tensor as pt


def angle(name, *, regularization=10.0, **kwargs):
Expand All @@ -21,15 +21,15 @@ def angle(name, *, regularization=10.0, **kwargs):
value of ``10.0`` is a good starting point.
"""
shape = kwargs.get("shape", ())
initval = kwargs.pop("initval", at.broadcast_to(0.0, shape))
initval = kwargs.pop("initval", pt.broadcast_to(0.0, shape))
x1 = pm.Normal(f"__{name}_angle1", initval=np.sin(initval), **kwargs)
x2 = pm.Normal(f"__{name}_angle2", initval=np.cos(initval), **kwargs)
if regularization is not None:
pm.Potential(
f"__{name}_regularization",
regularization * at.log(x1**2 + x2**2),
regularization * pt.log(x1**2 + x2**2),
)
return pm.Deterministic(name, at.arctan2(x1, x2))
return pm.Deterministic(name, pt.arctan2(x1, x2))


def unit_disk(name_x, name_y, **kwargs):
Expand Down Expand Up @@ -57,6 +57,6 @@ def unit_disk(name_x, name_y, **kwargs):
initval=initval[1] * np.sqrt(1 - initval[0] ** 2),
**kwargs,
)
norm = at.sqrt(1 - x1**2)
pm.Potential(f"__{name_y}_jacobian", at.log(norm))
norm = pt.sqrt(1 - x1**2)
pm.Potential(f"__{name_y}_jacobian", pt.log(norm))
return x1, pm.Deterministic(name_y, x2 * norm)
12 changes: 9 additions & 3 deletions src/pymc_ext/optim.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
__all__ = ["optimize"]

import pymc as pm
from aesara.graph.basic import graph_inputs
from aesara.tensor.var import TensorConstant, TensorVariable
from pytensor.graph.basic import graph_inputs
from pytensor.tensor.variable import TensorConstant, TensorVariable


def optimize(start=None, vars=None, **kwargs):
Expand All @@ -11,9 +11,15 @@ def optimize(start=None, vars=None, **kwargs):
if not isinstance(vars, (list, tuple)):
vars = [vars]

# In PyMC >= 5, model context is required to replace rvs with values
# https://github.com/pymc-devs/pymc/pull/6281
# https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html#pymc
model = kwargs.get("model")
model = pm.modelcontext(model)

# find_MAP only supports passing in members of free_RVs, so let's deal
# with that here...
vars = pm.aesaraf.rvs_to_value_vars(vars)
vars = model.replace_rvs_by_values(vars)
vars = [
v
for v in graph_inputs(vars)
Expand Down
25 changes: 19 additions & 6 deletions src/pymc_ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@


class Evaluator:
def __init__(self, outs, **kwargs):
"""Class to compile and evaluate components of a PyMC model
Args:
outs: The random variable, tensor, or list thereof to evaluate
model (Optional): PyMC model in which the variable are defined.
Tries to infer current model context if None.
**kwargs: All other kwargs are passed to pymc.pytensorf.compile_pymc.
"""

def __init__(self, outs, model=None, **kwargs):
# In PyMC >= 5, model context is required to replace rvs with values
# https://github.com/pymc-devs/pymc/pull/6281
# https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html#pymc
model = pm.modelcontext(None)
if isinstance(outs, (tuple, list)):
self.out_values = pm.aesaraf.rvs_to_value_vars(outs)
self.out_values = model.replace_rvs_by_values(outs)
else:
self.out_values = pm.aesaraf.rvs_to_value_vars([outs])[0]
self.out_values = model.replace_rvs_by_values([outs])[0]
self.in_values = pm.inputvars(self.out_values)
self.func = pm.aesaraf.compile_pymc(
self.func = pm.pytensorf.compile_pymc(
self.in_values, self.out_values, **kwargs
)

Expand All @@ -21,7 +34,7 @@ def __call__(self, point):


def eval_in_model(outs, point=None, model=None, seed=None, **kwargs):
"""Evaluate a Theano tensor or PyMC3 variable in a PyMC3 model
"""Evaluate a PyTensor tensor or PyMC variable in a PyMC model
This method builds a Theano function for evaluating a node in the graph
given the required parameters. This will also cache the compiled Theano
Expand All @@ -38,7 +51,7 @@ def eval_in_model(outs, point=None, model=None, seed=None, **kwargs):
if point is None:
model = pm.modelcontext(model)
point = model.initial_point(random_seed=seed)
return Evaluator(outs, **kwargs)(point)
return Evaluator(outs, model=model, **kwargs)(point)


def sample_inference_data(idata, size=1, random_seed=None, group="posterior"):
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
# ref: https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/conftest.py

import aesara
import pytensor
import pytest


@pytest.fixture(scope="package", autouse=True)
def theano_config():
flags = dict(compute_test_value="off")
config = aesara.configparser.change_flags(**flags)
config = pytensor.config.change_flags(**flags)
with config:
yield
2 changes: 1 addition & 1 deletion tests/optim_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import aesara.tensor as at
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

from pymc_ext.optim import optimize
Expand Down
29 changes: 28 additions & 1 deletion tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,33 @@ def test_eval_in_model(seed=123409):
x_val = np.random.randn(5, 3)
x_val2 = np.random.randn(5, 3)
with pm.Model():
x = pm.Normal("x", shape=x_val.shape, testval=x_val)
x = pm.Normal("x", shape=x_val.shape, initval=x_val)
assert np.allclose(eval_in_model(x), x_val)
assert np.allclose(eval_in_model(x, {"x": x_val2}), x_val2)


def test_eval_in_model_uniform(seed=123409):
# test_eval_in_model has unconstrained (-inf, inf) variables only
# Uniform has implicit transform in PyMC so check that this works too with
# eval_in_model
rng = np.random.default_rng(seed)
x_val = rng.uniform(size=(5, 3))
with pm.Model():
x = pm.Uniform("x", shape=x_val.shape, initval=x_val)

assert np.allclose(eval_in_model(x), x_val)


def test_eval_in_model_list(seed=123409):
# The utils.Evaluator class handles list of variables differently
# from single variables, so we test this here.
rng = np.random.default_rng(seed)
x_val = rng.uniform(size=(5, 3))
y_val = rng.standard_normal()

with pm.Model():
x = pm.Uniform("x", shape=x_val.shape, initval=x_val)
y = pm.Normal("y", initval=y_val)
x_eval, y_eval = eval_in_model([x, y])
assert np.allclose(x_eval, x_val)
assert np.allclose(y_eval, y_val)

0 comments on commit 6ea4f8e

Please sign in to comment.