Skip to content

Commit

Permalink
Allow Minibatch logp on derived RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 7, 2024
1 parent 623ca42 commit 6e0d317
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pymc/variational/minibatch_rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from pytensor.graph import Apply, Op
from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable

from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import logp


class MinibatchRandomVariable(MeasurableOp, Op):
Expand Down Expand Up @@ -99,4 +100,4 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
def minibatch_rv_logprob(op, values, *inputs, **kwargs):
[value] = values
rv, *total_size = inputs
return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)
return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)
10 changes: 10 additions & 0 deletions tests/variational/test_minibatch_rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from scipy import stats as st
Expand Down Expand Up @@ -186,3 +187,12 @@ def test_minibatch_parameter_and_value(self):
with m:
pm.set_data({"AD": rng.normal(size=1000)})
assert logp_fn(ip) != logp_fn(ip)

def test_derived_rv(self):
"""Test we can obtain a minibatch logp out of a derived RV."""
dist = pt.clip(pm.Normal.dist(0, 1, size=(1,)), -1, 1)
mb_dist = create_minibatch_rv(dist, total_size=(2,))
np.testing.assert_allclose(
pm.logp(mb_dist, -1).eval(),
pm.logp(dist, -1).eval() * 2,
)

0 comments on commit 6e0d317

Please sign in to comment.