-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of rao-blackwellised particle filter (rbpf) and rbpf with optimal resampling. #323
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Hi @murphyk @slinderman , I was wondering if one would kindly take a look at this PR implementing the RBPF. Would you let me know if there is something that needs to be changed? I understand that implementing a SLDS model class goes beyond what was raised in issue #271 and might not be necessary, in which case I'd be happy to fix this and make a more minimal implementation. |
weights_t = weights_t / weights_t.sum() | ||
|
||
indices = jnp.arange(nparticles) | ||
pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe use multinomial resampling?
For the resampling step, maybe it would be worth making a jax version of https://github.com/nchopin/particles/blob/master/particles/resampling.py#L541? |
LGTM. We can modify resampling code later (if necessary). |
Thanks, @kostastsa. One of the big items on my wishlist is to have a solid implementation of SLDS model variants and inference algorithms in dynamax. This is a great start. |
@murphyk Cool, I will check the multinomial resampling from Chopin and do a jax version. |
SLDS
based onSSM
. I have put everything in a folderdynamax/slds
.inference.py
file. Other inference methods for SLDS, (such as generalized pseudo Bayes etc) can be included in the same file.inference_test.py
which compares the new implementation to an older one referenced in implement rao-blackwellised particle filtering using dynamax and blackjax #271 which has been included indynamax/slds/mixture_kalman_filter.py
. Since the algorithm is stochastic and there is no exact baseline I wasn't sure how to test it, so I ended up doing anallclose
with a large tolerance between the two implementations.