Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default moment for CustomDist provided with a dist function #6873

Merged
merged 41 commits into from
Nov 13, 2023

Conversation

aerubanov
Copy link
Contributor

@aerubanov aerubanov commented Aug 22, 2023

Add default implementation of moment for CustomDist with a dist function and close #6804


📚 Documentation preview 📚: https://pymc--6873.org.readthedocs.build/en/6873/

@aerubanov aerubanov marked this pull request as draft August 22, 2023 15:13
@codecov
Copy link

codecov bot commented Aug 22, 2023

Codecov Report

Merging #6873 (9b3c43d) into main (ec24ce6) will decrease coverage by 35.06%.
Report is 1 commits behind head on main.
The diff coverage is 23.07%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #6873       +/-   ##
===========================================
- Coverage   92.19%   57.13%   -35.06%     
===========================================
  Files         101      101               
  Lines       16921    16964       +43     
===========================================
- Hits        15600     9693     -5907     
- Misses       1321     7271     +5950     
Files Coverage Δ
pymc/distributions/distribution.py 60.57% <23.07%> (-35.78%) ⬇️

... and 72 files with indirect coverage changes

@aerubanov
Copy link
Contributor Author

@ricardoV94 I added implementation which use graph rewriting to replace distributions by corresponding moments. Could you please take a look?

@aerubanov aerubanov marked this pull request as ready for review September 4, 2023 17:32
@aerubanov aerubanov marked this pull request as draft September 4, 2023 17:32
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat approach!

Just need to think about Ops with InnerGraphs

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
@aerubanov
Copy link
Contributor Author

@ricardoV94 Do you know som simple example of symbolic dist with Scan op? I want to add test case with Scan, but I do not used it before, so when I try this:

def test_custom_dist_custom_moment_inner_graph(self):
        def dist(mu, size):
            ys, _ = pytensor.scan(
                fn=lambda x: pt.exp(pm.Normal.dist(x, 1, size=size)),
                sequences=[mu],
                outputs_info=[None], 
                name="ys",
                )
            return pt.sum(ys)
        
        with Model() as model:
            CustomDist("x", pt.ones(2), dist=dist)
        assert_moment_is_expected(model, 2)

I get ValueError: No update found for at least one RNG used in Scan Op Scan{ys, while_loop=False, inplace=none}. I guess I just use scan wrong

@ricardoV94
Copy link
Member

You need to return random updates from the scan. Check the utility collect_default_updates and the code example in the docstrings:

https://github.com/pymc-devs/pymc/blob/main/pymc/pytensorf.py#L1000

@aerubanov
Copy link
Contributor Author

@ricardoV94 yeah, it work, thank you!

@aerubanov
Copy link
Contributor Author

@ricardoV94 I added replacements inside inner graph, could you please take a look?

@aerubanov
Copy link
Contributor Author

@ricardoV94 could you please to check my comment above when you will have some time?

@aerubanov
Copy link
Contributor Author

@ricardoV94 just friendly reminder about example for OpFromGraph )

@ricardoV94
Copy link
Member

ricardoV94 commented Oct 2, 2023

@aerubanov I couldn't come up with a working example of an OpFromGraph with a RandomVariable inside (should be fine outside), except other SymbolicRandomVariables. So I think we ignore for now .

However, this is a good reminder. Do you have a test showing that the moment works for a nested CustomDist or a SymbolicRandomVariable (like pm.Truncated) inside a CustomDist?

Something like:

import pymc as pm

def dist_fn(size):
    return pm.Truncated.dist(pm.Normal.dist(), -1, 1, size=size) + 5

x = pm.CustomDist.dist(dist=dist_fn)

@aerubanov
Copy link
Contributor Author

aerubanov commented Oct 2, 2023

@ricardoV94 Yeah, I need to add test case for pm.Truncated because I have not it now. And may be will need to add some changes in my implementation, because I did not take into account this case when was worked on my implementation. So thank you for pointing out on this!

@aerubanov aerubanov marked this pull request as ready for review October 3, 2023 15:05
new_node = rewrite_moment_scan_node(node)
for out1, out2 in zip(node.outputs, new_node.outputs):
fgraph.replace(out1, out2)
elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
Copy link
Member

@ricardoV94 ricardoV94 Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this branch should come first

@@ -687,6 +747,7 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs):

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
params = [i for i in params if not isinstance(i, RandomGeneratorSharedVariable)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

params consists RandomGeneratorSharedVariables, which are not match with dist function signature. So I filter out it here, but may be there is a better way to do it

Copy link
Member

@ricardoV94 ricardoV94 Oct 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think below we do something like params[:len(dist_params)]? Does that work? Would be nice to add a comment to say we are excluding the shared RNGs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think params[:len(dist_params)] should work here, but do we have access to dist_params from this function? Looks like no.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case let's keep like you did, but perhaps use a helper function with a readable name? Could that helper be used below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added helper function to filter out shared RNGs

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CustomDist is getting more complex, we might want to move it into its own file later down the road.

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
pymc/distributions/distribution.py Outdated Show resolved Hide resolved
Comment on lines 689 to 688
fgraph = get_rv_fgraph(dist, dist_params, size)
replace_moments = MomentRewrite()
Copy link
Member

@ricardoV94 ricardoV94 Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still worried that this will rewrite dist_params as well, and not just the CustomDist graph between dist and dist_params.

Would the following work? Clone the inner fgraph from rv.owner.op.fgraph. This graph doesn't have dist_params directly but dummy placeholders called NominalVariables. Apply the rewrite on this inner graph, and once you are done, replace the NominalVariables by the respective dist_params?

Copy link
Member

@ricardoV94 ricardoV94 Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually what I am suggesting is pretty similar to what you did with Scan, so you could perhaps reuse some of the same logic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 When I trying to create FunctionGraph with NominalVariables I get error about missing input values - I think FunctionGraph constructor do not support dummy placeholders as input. But I`m try another approach:

def dist_moment(rv, size, *dist_params, dist):
    rv = dist(*dist_params, size=size)
    inputs, outputs = dist_params, [rv.owner.out]
    fgraph_topo = io_toposort(inputs, outputs)
    replace_with_moment = []
    to_replace_set = set()

    for nd in fgraph_topo:
          if nd not in to_replace_set and isinstance(
               nd.op, (RandomVariable, SymbolicRandomVariable)
          ):
                replace_with_moment.append(nd.out)
                to_replace_set.add(nd)
     givens = {}
     for item in replace_with_moment:
            givens[item] = moment(item)
     [out] = clone_replace(outputs, replace=givens)
     return out

Looks like it work but do not support Scan for now (but I can add it). Do you think such approach will be better?

Copy link
Member

@ricardoV94 ricardoV94 Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't have to recreate the FunctionGraph, you can just do fgraph.clone()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 I try this approach, but I need to create CustomDist op first because just output of dist function do not have fgraph attribute.

def dist_moment(rv, size, *dist_params):
            size = normalize_size_param(size)
            dummy_size_param = size.type()
            dummy_dist_params = [dist_param.type() for dist_param in dist_params]
            dummy_rv = dist(*dummy_dist_params, dummy_size_param)
            dummy_params = [dummy_size_param] + dummy_dist_params
            rv_type = type(
                class_name,
                (CustomSymbolicDistRV,),
                # If logp is not provided, we try to infer it from the dist graph
                dict(
                    inline_logprob=logp is None,
                ),
            )
            rv_op = rv_type(
                inputs=dummy_params,
                outputs=[dummy_rv],
                ndim_supp=ndim_supp,
            )
            fgraph = rv_op.fgraph.clone()
            replace_moments = MomentRewrite()
            replace_moments.rewrite(fgraph)
            for i, par in enumerate([size] + list(dist_params)):
                fgraph.replace(fgraph.inputs[i], par)
            [moment] = fgraph.outputs
            return moment

Do you think this way better? It is also do not work with Scan yet. I need to figure out why and I will move graph creation logic to helper function if we will decide to keep this way.

Copy link
Contributor Author

@aerubanov aerubanov Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For scan I got Assertion Error:

        # allocate storage for intermediate computation
        for node in order:
            for r in node.inputs:
                if r not in storage_map:
>                   assert isinstance(r, Constant)
E                   AssertionError

../../miniconda3/envs/pymc-dev/lib/python3.11/site-packages/pytensor/link/utils.py:135: AssertionError

It is hapens when I run TestCustomSymbolicDist::test_custom_dist_default_moment_inner_graph test case

Copy link
Contributor Author

@aerubanov aerubanov Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 If I create rv like this it`s do not have fgraph atribute:

rv = dist(*dist_params, size=size)
>>> rv
Exp.0
>>> rv.owner
Exp(normal_rv{0, (0, 0), floatX, False}.out)
>>> rv.owner.op
Elemwise(scalar_op=exp,inplace_pattern=<frozendict {}>)
>>> rv.owner.op.fgraph
*** AttributeError: 'Elemwise' object has no attribute 'fgraph'

I use lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)) for dist function here.

Copy link
Member

@ricardoV94 ricardoV94 Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, use the rv that is passed to the moment function directly (first argument), don't recreate it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohm yeah, my bad. Will try it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 No, I can not use rv that is passes as parameter:

>>> rv
*** Not yet returned!

So I tried to re-create it

@aerubanov
Copy link
Contributor Author

@ricardoV94 Could you please take a look on recent changes? Just friendly reminder)

@@ -622,13 +683,28 @@ def dist(
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

def dist_moment(rv, size, *dist_params):
fgraph = rv.owner.op.fgraph.clone()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method can now be used for any OpFromGraph not just CustomDist ones, so I would move it to the MomentRewriter

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
pymc/distributions/distribution.py Outdated Show resolved Hide resolved
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really great. Just some minor cleaning up suggestions left

Comment on lines 142 to 143
for out1, out2 in zip(node.outputs, new_node.outputs):
fgraph.replace(out1, out2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use replace_all?

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
has_fallback=True,
ndim_supp=ndim_supp,
)
moment = dist_moment
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this if statement (see other comment about default)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 Hm, I think need this condition to avoid overriding moment provided by user, how can we avoid it without condition?

Copy link
Contributor Author

@aerubanov aerubanov Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More specifically, If I try to remove if statement TestCustomSymbolicDist::test_custom_methods fails

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two levels. We dispatch the general moment to the parent base class. Whenever a new subclass is created here and the user provided a moment we register it on the subclass (so it has preference). If the user didn't provide anything, we don't register and the parent class one will be used.

tests/distributions/test_distribution.py Outdated Show resolved Hide resolved
Comment on lines 88 to 89
def filter_RNGs(params):
return [p for p in params if not isinstance(p.type, (RandomType, RandomGeneratorType))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this ended up not being needed let's just inline it in the custom moment function

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

Great work @aerubanov!

@ricardoV94 ricardoV94 merged commit ad450a6 into pymc-devs:main Nov 13, 2023
20 of 22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement default moment function for CustomDist provided with a dist function
2 participants