Skip to content

Commit

Permalink
Allow for batched alpha in StickBreakingWeights (#6042)
Browse files Browse the repository at this point in the history
Co-authored-by: Sayam Kumar <[email protected]>
  • Loading branch information
purna135 and Sayam753 authored Aug 31, 2022
1 parent c8ce9c9 commit 0b191ad
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 16 deletions.
24 changes: 8 additions & 16 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,32 +2200,23 @@ def make_node(self, rng, size, dtype, alpha, K):
alpha = at.as_tensor_variable(alpha)
K = at.as_tensor_variable(intX(K))

if alpha.ndim > 0:
raise ValueError("The concentration parameter needs to be a scalar.")

if K.ndim > 0:
raise ValueError("K must be a scalar.")

return super().make_node(rng, size, dtype, alpha, K)

def _infer_shape(self, size, dist_params, param_shapes=None):
alpha, K = dist_params

size = tuple(size)

return size + (K + 1,)
def _supp_shape_from_params(self, dist_params, **kwargs):
K = dist_params[1]
return (K + 1,)

@classmethod
def rng_fn(cls, rng, alpha, K, size):
if K < 0:
raise ValueError("K needs to be positive.")

if size is None:
size = (K,)
elif isinstance(size, int):
size = (size,) + (K,)
else:
size = tuple(size) + (K,)
size = to_tuple(size) if size is not None else alpha.shape
size = size + (K,)
alpha = alpha[..., np.newaxis]

betas = rng.beta(1, alpha, size=size)

Expand Down Expand Up @@ -2294,9 +2285,10 @@ def dist(cls, alpha, K, *args, **kwargs):
return super().dist([alpha, K], **kwargs)

def moment(rv, size, alpha, K):
alpha = alpha[..., np.newaxis]
moment = (alpha / (1 + alpha)) ** at.arange(K)
moment *= 1 / (1 + alpha)
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1)
if not rv_size_is_none(size):
moment_size = at.concatenate(
[
Expand Down
31 changes: 31 additions & 0 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from aeppl.logprob import ParameterValueError
from aesara.tensor.random.utils import broadcast_params

from pymc.aesaraf import compile_pymc
from pymc.distributions.continuous import get_tau_sigma
from pymc.util import UNSET

Expand Down Expand Up @@ -953,6 +954,17 @@ def test_hierarchical_obs_logp():
assert not any(isinstance(o, RandomVariable) for o in ops)


@pytest.fixture(scope="module")
def stickbreakingweights_logpdf():
_value = at.vector()
_alpha = at.scalar()
_k = at.iscalar()
_logp = logp(StickBreakingWeights.dist(_alpha, _k), _value)
core_fn = compile_pymc([_value, _alpha, _k], _logp)

return np.vectorize(core_fn, signature="(n),(),()->()")


class TestMatchesScipy:
def test_uniform(self):
check_logp(
Expand Down Expand Up @@ -2318,6 +2330,25 @@ def test_stickbreakingweights_invalid(self):
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf

@pytest.mark.parametrize(
"alpha,K",
[
(np.array([0.5, 1.0, 2.0]), 3),
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
],
)
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
value = pm.StickBreakingWeights.dist(alpha, K).eval()
with Model():
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
pm.logp(sbw, value).eval(),
stickbreakingweights_logpdf(value, alpha, K),
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
)

@aesara.config.change_flags(compute_test_value="raise")
def test_categorical_bounds(self):
with Model():
Expand Down
26 changes: 26 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected):
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
),
),
(
np.array([1, 3]),
11,
None,
np.array(
[
np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11),
np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
]
),
),
(
np.array([1, 3, 5]),
9,
(5, 3),
np.full(
shape=(5, 3, 10),
fill_value=np.array(
[
np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9),
np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9),
np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9),
]
),
),
),
],
)
def test_stickbreakingweights_moment(alpha, K, size, expected):
Expand Down
12 changes: 12 additions & 0 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,18 @@ def check_basic_properties(self):
assert np.all(draws <= 1)


class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom):
pymc_dist = pm.StickBreakingWeights
pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
sizes_to_check = [None, (3,), (5, 3)]
sizes_expected = [(3, 20), (3, 20), (5, 3, 20)]
checks_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
]


class TestCategorical(BaseTestDistributionRandom):
pymc_dist = pm.Categorical
pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}
Expand Down

0 comments on commit 0b191ad

Please sign in to comment.