Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delayed param #534

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

Delayed param #534

wants to merge 5 commits into from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Apr 23, 2021

Addresses #533

Group coded with @fritzo @eb8680 @fehiepsi

Yerdos Ordabayev added 2 commits April 23, 2021 09:50
@fritzo fritzo added the examples Examples and tutorials label Apr 23, 2021
Yerdos Ordabayev added 2 commits April 23, 2021 21:18

import funsor
from funsor.adam import Adam # noqa: F401
Copy link
Member Author

@ordabayevy ordabayevy Apr 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for compatibility with pyroapi

value, _ = PARAM_STORE[name]
if event_dim is None:
event_dim = value.dim()
output = funsor.Reals[value.shape[value.dim() - event_dim :]]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infer output when pyro.param was already defined elsewhere


def step(self, *args, **kwargs):
self.optim.num_steps = 1
return self.run(*args, **kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for compatibility with SVI interface

Copy link
Member

@fritzo fritzo Apr 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not be persisted across svi steps.

One option is simply to change pyroapi's SVI interface to look for either .run() or if missing fall back to .step(). Also I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).

for p in params:
p.grad = torch.zeros_like(p.grad)
return loss.item()
with funsor.terms.lazy:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lazy interpretation is needed here to make sure that funsor.Integrate is not eagerly expanded in Expectation

@ordabayevy
Copy link
Member Author

examples/minipyro.py is currently failing with jit. My guess is that jitting needs to be baked into funsor.adam.Adam, I will think more about this.

@ordabayevy ordabayevy added the WIP label Apr 24, 2021
@ordabayevy
Copy link
Member Author

Am I right that when using funsor.adam.Adam the function that needs to be jit traced is the loss function below (Subs funsor) @fritzo @eb8680 ? If yes, then it first needs to be converted to a function with positional arguments?

step_loss = loss(**{k: v[...] for k, v in params.items()}).data

@fritzo
Copy link
Member

fritzo commented Apr 25, 2021

...failing with jit. My guess is that jitting needs to be baked into funsor.adam.Adam

I think you're right, but let's discuss. That's a little different from Pyro where jit is baked into ELBO subclasses.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples Examples and tutorials WIP
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants