-
-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expand models supported by automatic marginalization #300
base: main
Are you sure you want to change the base?
Expand models supported by automatic marginalization #300
Conversation
@@ -586,7 +825,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs | |||
) | |||
|
|||
marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) | |||
dependent_rvs_input_rvs = [ | |||
other_direct_rv_ancestors = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does other mean here in terms of the variable name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's the set of model RVs that are direct ancestors to the dependent rvs, excluding the marginal rv
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if there was some intuition for why some programs can be marginalized and some cannot. The distinction is a bit and for some cases it feels like we should be able to.
You can marginalize models as long as they don't mix dimensions of the marginal RV (and as long as we can be sure of that). So if |
2aa1e81
to
b0d0bce
Compare
This PR now depends on pymc-devs/pymc#7159 |
b0d0bce
to
ffb82c3
Compare
Aside from being re-based, are there any blockers to merging this? |
@ricardoV94 just had a baby so probably won't respond. Do you want to rebase and we can just merge for now? |
Nah I can still type now and then. PR was not yet done |
Congrats @ricardoV94 🥳 |
pytest is configured with the same behavior globally
This commit lifts the restriction that only Elemwise operations may link marginalized to dependent RVs. We map input dims to output dims, to assess whether an operation mixes information from different dims or not. Graphs where information is not mixed can be efficiently marginalized.
ffb82c3
to
378dbe4
Compare
9341906
to
01ed4c0
Compare
This PR allows more kinds of graphs to be marginalized. Previously, we were limiting it to Elemwise operations to ensure that information across batch dimensions was not mixed between the marginalized and dependent RVs, so as to generate an efficient logp expression that did not grow with number of batch dimensions, but was constant on the domain of the marginalized variables.
We still have the same constraint, but can now analyze it more carefully by propagating information about the batch dimensions of the marginalized RV across the intermediate operations. This allows operations like DimShuffle (transposition, expand_dims, squeeze), Blockwise, Reductions and, rather important, flavors of basic and advanced indexing like the following:
This should expand the range of models supported and open room for further expansions.
Some limitations were introduced to simplify the internal logic such as:
pm.Bernoulli(..., shape=(5, 1))
is not allowed, butpm.Bernoulli(..., shape=(5,))[:, None]
is. This simplifies the logic by only having one kind of possible dim in each axis.TODO:
subgrah_dims
utility directly