Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default moment for CustomDist provided with a dist function #6873

Merged
merged 41 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
030824e
add test for custom dist default moment
aerubanov Aug 22, 2023
cdbe6f8
add graph rewriting
aerubanov Sep 1, 2023
af82efc
add graph rewrite
aerubanov Sep 4, 2023
c12da7e
change test case
aerubanov Sep 4, 2023
b570fb6
remove commented code
aerubanov Sep 4, 2023
ea22848
add more test cases
aerubanov Sep 4, 2023
06c2646
replace Distribution by RandomVariable in node input check
aerubanov Sep 7, 2023
511f0f6
Update pymc/distributions/distribution.py
aerubanov Sep 9, 2023
9a4a801
add test case for dist with inner graph
aerubanov Sep 11, 2023
10ba63d
add extra test case and change apply method
aerubanov Sep 12, 2023
f078665
change inner graph moment replacement
aerubanov Sep 12, 2023
cee024d
remove comented code
aerubanov Sep 12, 2023
6e84d62
add initial implementation of scan op replacement
aerubanov Sep 22, 2023
aa8fce9
fix errors
aerubanov Sep 22, 2023
716f1f8
change test
aerubanov Sep 22, 2023
62b3c59
change normal ditribution by uniform in step func
aerubanov Sep 25, 2023
875e2ed
add test case for nested dist
aerubanov Oct 3, 2023
e1b0f66
remove unused and commented code
aerubanov Oct 5, 2023
3c30352
change conditions order
aerubanov Oct 5, 2023
c017195
remove comment
aerubanov Oct 5, 2023
6cb03f1
add helper function
aerubanov Oct 16, 2023
1b60994
add helper function for graph construction
aerubanov Oct 24, 2023
0656987
Update pymc/distributions/distribution.py
aerubanov Oct 30, 2023
155e44b
fix function name and imports
aerubanov Oct 31, 2023
097a057
transform scan rewrite function into method of MomentRewrite
aerubanov Oct 31, 2023
2647993
add check for no replacements needed
aerubanov Oct 31, 2023
a0ae812
fix method arguments
aerubanov Oct 31, 2023
ef7a7a0
change fgraph construction
aerubanov Nov 1, 2023
3004799
remove rv creation
aerubanov Nov 6, 2023
6b7c11b
remove commented code
aerubanov Nov 6, 2023
9db17f6
add comments
aerubanov Nov 6, 2023
b3548cb
Update tests/distributions/test_distribution.py
aerubanov Nov 10, 2023
20e066c
move dist_moment function outside of dist method
aerubanov Nov 10, 2023
db49e97
add new test case
aerubanov Nov 10, 2023
80c6b02
add changes from review
aerubanov Nov 13, 2023
ba00b38
remove filter_RNGs function
aerubanov Nov 13, 2023
a3c9f14
remove moment from dist method
aerubanov Nov 13, 2023
ed5a3c7
register moment fn
aerubanov Nov 13, 2023
9ec33c6
Update pymc/distributions/distribution.py
aerubanov Nov 13, 2023
d5899d4
fix moment arguments
aerubanov Nov 13, 2023
9b3c43d
remove separate var for op
aerubanov Nov 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 86 additions & 15 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@

from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
from pytensor.graph.basic import Node, Variable, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -83,6 +85,59 @@
PLATFORM = sys.platform


class MomentRewrite(GraphRewriter):
def rewrite_moment_scan_node(self, node):
if not isinstance(node.op, Scan):
return

Check warning on line 91 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L90-L91

Added lines #L90 - L91 were not covered by tests

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
op = node.op

Check warning on line 94 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L93-L94

Added lines #L93 - L94 were not covered by tests

local_fgraph_topo = io_toposort(node_inputs, node_outputs)

Check warning on line 96 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L96

Added line #L96 was not covered by tests

replace_with_moment = []
to_replace_set = set()

Check warning on line 99 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L98-L99

Added lines #L98 - L99 were not covered by tests

for nd in local_fgraph_topo:
if nd not in to_replace_set and isinstance(

Check warning on line 102 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L101-L102

Added lines #L101 - L102 were not covered by tests
nd.op, (RandomVariable, SymbolicRandomVariable)
):
replace_with_moment.append(nd.out)
to_replace_set.add(nd)
givens = {}
if len(replace_with_moment) > 0:
for item in replace_with_moment:
givens[item] = moment(item)

Check warning on line 110 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L105-L110

Added lines #L105 - L110 were not covered by tests
else:
return
op_outs = clone_replace(node_outputs, replace=givens)

Check warning on line 113 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L112-L113

Added lines #L112 - L113 were not covered by tests

nwScan = Scan(

Check warning on line 115 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L115

Added line #L115 was not covered by tests
node_inputs,
op_outs,
op.info,
mode=op.mode,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
name=op.name,
allow_gc=op.allow_gc,
)
nw_node = nwScan(*(node.inputs), return_list=True)[0].owner
return nw_node

Check warning on line 126 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L125-L126

Added lines #L125 - L126 were not covered by tests

def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())

Check warning on line 129 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L129

Added line #L129 was not covered by tests

def apply(self, fgraph):
for node in fgraph.toposort():
if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
fgraph.replace(node.out, moment(node.out))
elif isinstance(node.op, Scan):
new_node = self.rewrite_moment_scan_node(node)
if new_node is not None:
fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))

Check warning on line 138 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L132-L138

Added lines #L132 - L138 were not covered by tests


class _Unpickling:
pass

Expand Down Expand Up @@ -601,6 +656,20 @@
return updates


@_moment.register(CustomSymbolicDistRV)
def dist_moment(op, rv, *args):
node = rv.owner
rv_out_idx = node.outputs.index(rv)

Check warning on line 662 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L661-L662

Added lines #L661 - L662 were not covered by tests

fgraph = op.fgraph.clone()
replace_moments = MomentRewrite()
replace_moments.rewrite(fgraph)

Check warning on line 666 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L664-L666

Added lines #L664 - L666 were not covered by tests
# Replace dummy inner inputs by outer inputs
fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True)
moment = fgraph.outputs[rv_out_idx]
return moment

Check warning on line 670 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L668-L670

Added lines #L668 - L670 were not covered by tests


class _CustomSymbolicDist(Distribution):
rv_type = CustomSymbolicDistRV

Expand All @@ -622,14 +691,6 @@
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if moment is None:
moment = functools.partial(
default_moment,
rv_name=class_name,
has_fallback=True,
ndim_supp=ndim_supp,
)

return super().dist(
dist_params,
class_name=class_name,
Expand Down Expand Up @@ -685,9 +746,19 @@
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
return moment(rv, size, *params[: len(params)])
if moment is not None:

Check warning on line 749 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L749

Added line #L749 was not covered by tests

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
return moment(

Check warning on line 753 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L751-L753

Added lines #L751 - L753 were not covered by tests
rv,
size,
*[
p
for p in params
if not isinstance(p.type, (RandomType, RandomGeneratorType))
],
)

@_change_dist_size.register(rv_type)
def change_custom_symbolic_dist_size(op, rv, new_size, expand):
Expand Down
98 changes: 98 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,104 @@ def custom_dist(mu, sigma, size):
ip = m.initial_point()
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))

@pytest.mark.parametrize(
"dist_params, size, expected, dist_fn",
[
(
(5, 1),
None,
np.exp(5),
lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
),
(
(2, np.ones(5)),
None,
np.exp([2, 2, 2, 2, 2] + np.ones(5)),
lambda mu, sigma, size: pt.exp(
pm.Normal.dist(mu, sigma, size=size) + pt.ones(size)
),
),
(
(1, 2),
None,
np.sqrt(np.exp(1 + 0.5 * 2**2)),
lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
),
(
(4,),
(3,),
np.log([4, 4, 4]),
lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
),
(
(12, 1),
None,
12,
lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
),
],
)
def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn):
with Model() as model:
CustomDist("x", *dist_params, dist=dist_fn, size=size)
assert_moment_is_expected(model, expected)

def test_custom_dist_default_moment_scan(self):
def scan_step(left, right):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
x = pm.Uniform.dist(left, right)
x_update = collect_default_updates([x])
return x, x_update

def dist(size):
xs, updates = scan(
fn=scan_step,
sequences=[
pt.as_tensor_variable(np.array([-4, -3])),
pt.as_tensor_variable(np.array([-2, -1])),
],
name="xs",
)
return xs

with Model() as model:
CustomDist("x", dist=dist)
assert_moment_is_expected(model, np.array([-3, -2]))

def test_custom_dist_default_moment_scan_recurring(self):
def scan_step(xtm1):
x = pm.Normal.dist(xtm1 + 1)
x_update = collect_default_updates([x])
return x, x_update

def dist(size):
xs, _ = scan(
fn=scan_step,
outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
n_steps=3,
name="xs",
)
return xs

with Model() as model:
CustomDist("x", dist=dist)
assert_moment_is_expected(model, np.array([[1], [2], [3]]))

@pytest.mark.parametrize(
"left, right, size, expected",
[
(-1, 1, None, 0 + 5),
(-3, -1, None, -2 + 5),
(-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
],
)
def test_custom_dist_default_moment_nested(self, left, right, size, expected):
def dist_fn(left, right, size):
return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5

with Model() as model:
CustomDist("x", left, right, size=size, dist=dist_fn)
assert_moment_is_expected(model, expected)

def test_logcdf_inference(self):
def custom_dist(mu, sigma, size):
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
Expand Down
Loading