Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand models supported by automatic marginalization #300

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 28, 2024

This PR allows more kinds of graphs to be marginalized. Previously, we were limiting it to Elemwise operations to ensure that information across batch dimensions was not mixed between the marginalized and dependent RVs, so as to generate an efficient logp expression that did not grow with number of batch dimensions, but was constant on the domain of the marginalized variables.

We still have the same constraint, but can now analyze it more carefully by propagating information about the batch dimensions of the marginalized RV across the intermediate operations. This allows operations like DimShuffle (transposition, expand_dims, squeeze), Blockwise, Reductions and, rather important, flavors of basic and advanced indexing like the following:

import pymc as pm
import pytensor.tensor as pt

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    state = pm.Categorical("state", p=[0.1, 0.3, 0.6], shape=(4,))
    # Advanced indexing was not supported before
    # The indexed variable could be an RV as well!
    mu = pt.as_tensor([-10, 0, 10])[state]
    sigma = pm.HalfNormal("sigma")
    emission = pm.Normal("emission", mu, sigma, observed=[-9.0, -0.5, 1.0, 11.0])
    
m.marginalize(state)
m.point_logps()
# {'sigma': -0.73, 'emission': -10.52}

This should expand the range of models supported and open room for further expansions.

Some limitations were introduced to simplify the internal logic such as:

  1. Marginal RVs can't have explict broadcastable batch dims, but can be expanded explicitly or implicitly downstream. Thus, pm.Bernoulli(..., shape=(5, 1)) is not allowed, but pm.Bernoulli(..., shape=(5,))[:, None] is. This simplifies the logic by only having one kind of possible dim in each axis.
  2. Dependent RVs with batch dims beyond those introduced by the Marginal RV, must have such batch dims to the left

TODO:

  • Give more informative error for explicit broadcastable dims
  • Document internal logic more carefully
  • Allow dependent multivariate RVs (it should almost work out of the box, just need some tweaks and tests)
  • Test subgrah_dims utility directly
  • Test MvNormal k-clusters model that showed up in Discourse some time ago
  • Test dependent RVs batch to the left restriction
  • Allow a couple more simple Ops

@ricardoV94 ricardoV94 added the enhancements New feature or request label Jan 28, 2024
@@ -586,7 +825,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
)

marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
dependent_rvs_input_rvs = [
other_direct_rv_ancestors = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does other mean here in terms of the variable name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the set of model RVs that are direct ancestors to the dependent rvs, excluding the marginal rv

Copy link
Contributor

@zaxtax zaxtax left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if there was some intuition for why some programs can be marginalized and some cannot. The distinction is a bit and for some cases it feels like we should be able to.

@ricardoV94
Copy link
Member Author

ricardoV94 commented Feb 1, 2024

It would be nice if there was some intuition for why some programs can be marginalized and some cannot. The distinction is a bit and for some cases it feels like we should be able to.

You can marginalize models as long as they don't mix dimensions of the marginal RV (and as long as we can be sure of that).

So if idx is the variable you are marginalizing, a direct dependent variable could use as a parameter idx + idx or idx.T + idx.T, but not idx + idx.T because it mixes distinct dimensions. Similarly, something like sum(idx) is not allowed.

@ricardoV94
Copy link
Member Author

This PR now depends on pymc-devs/pymc#7159

@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from b0d0bce to ffb82c3 Compare February 16, 2024 18:05
@zaxtax
Copy link
Contributor

zaxtax commented Jul 26, 2024

Aside from being re-based, are there any blockers to merging this?

@twiecki
Copy link
Member

twiecki commented Jul 26, 2024

Aside from being re-based, are there any blockers to merging this?

@ricardoV94 just had a baby so probably won't respond. Do you want to rebase and we can just merge for now?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jul 26, 2024

Nah I can still type now and then. PR was not yet done

@zaxtax
Copy link
Contributor

zaxtax commented Jul 26, 2024

Congrats @ricardoV94 🥳

pytest is configured with the same behavior globally
This commit lifts the restriction that only Elemwise operations may link marginalized to dependent RVs. We map input dims to output dims, to assess whether an operation mixes information from different dims or not. Graphs where information is not mixed can be efficiently marginalized.
@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from ffb82c3 to 378dbe4 Compare September 15, 2024 22:02
@ricardoV94 ricardoV94 force-pushed the extend_automatic_marginalization branch from 9341906 to 01ed4c0 Compare September 18, 2024 12:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request marginalization
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants