Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of rao-blackwellised particle filter (rbpf) and rbpf with optimal resampling. #323

Merged
merged 13 commits into from
Nov 1, 2023

Conversation

kostastsa
Copy link
Contributor

@kostastsa kostastsa commented May 24, 2023

  • Addresses implement rao-blackwellised particle filtering using dynamax and blackjax #271.
  • To implement this I have created a new model class SLDS based on SSM. I have put everything in a folder dynamax/slds.
  • The filters are part of a inference.py file. Other inference methods for SLDS, (such as generalized pseudo Bayes etc) can be included in the same file.
  • Also, methods for learning SLDS can be build on this model.
  • I have included a file inference_test.py which compares the new implementation to an older one referenced in implement rao-blackwellised particle filtering using dynamax and blackjax #271 which has been included in dynamax/slds/mixture_kalman_filter.py. Since the algorithm is stochastic and there is no exact baseline I wasn't sure how to test it, so I ended up doing an allclose with a large tolerance between the two implementations.
  • From the plots and MSE it is seen that new implementation does what it is supposed to do.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@kostastsa
Copy link
Contributor Author

Hi @murphyk @slinderman , I was wondering if one would kindly take a look at this PR implementing the RBPF. Would you let me know if there is something that needs to be changed? I understand that implementing a SLDS model class goes beyond what was raised in issue #271 and might not be necessary, in which case I'd be happy to fix this and make a more minimal implementation.

dynamax/slds/mixture_kalman_filter.py Outdated Show resolved Hide resolved
weights_t = weights_t / weights_t.sum()

indices = jnp.arange(nparticles)
pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use multinomial resampling?

dynamax/ssm.py Outdated Show resolved Hide resolved
@murphyk
Copy link
Member

murphyk commented Oct 26, 2023

For the resampling step, maybe it would be worth making a jax version of https://github.com/nchopin/particles/blob/master/particles/resampling.py#L541?

@murphyk
Copy link
Member

murphyk commented Nov 1, 2023

LGTM. We can modify resampling code later (if necessary).

@murphyk murphyk merged commit a1cb385 into probml:main Nov 1, 2023
2 checks passed
@slinderman
Copy link
Collaborator

Thanks, @kostastsa. One of the big items on my wishlist is to have a solid implementation of SLDS model variants and inference algorithms in dynamax. This is a great start.

@kostastsa
Copy link
Contributor Author

@murphyk Cool, I will check the multinomial resampling from Chopin and do a jax version.
@slinderman That's great to hear! I would be super interested in being part of this effort. I can start working on this soon, since I will also need it for my own research.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants