Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use reverse=True keyword argument in lax.scan for smoothers #365

Merged
merged 2 commits into from
Jun 26, 2024

Conversation

edeno
Copy link
Contributor

@edeno edeno commented Jun 25, 2024

Fixes issue #364 by using the reverse=True keyword argument in lax.scan function.

Tests are passing locally except for the unscented kalman filter inference tests, but these were also failing for the original code as far as I can tell?

I also had to pin numpy < 2.0 because tensorflow_probability was failing (note that this is also a problem in the docs.

There could potentially be further improvement in eliminating unnecessary memory copies by array slicing but I think it would destroy some of the readability of the code and result in some computation overhead. For example, the stack operations (jnp.vstack([smoothed_probs, filtered_probs[-1]]) and slicing filtered_probs[:-1]) create copies. This isn't really a problem unless you have a large number of states:

    # Run the HMM smoother
    _, smoothed_probs = lax.scan(
        _step,
        filtered_probs[-1],
        (jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]),
        reverse=True,
    )

    # Concatenate the arrays and return
    smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])

@edeno edeno changed the title Use reverse=True keyword argument in lax.scan Use reverse=True keyword argument in lax.scan for smoothers Jun 25, 2024
Copy link
Collaborator

@slinderman slinderman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks! Unfortunately TFP broke for numpy 2.0. I opened an issue about it last week, but they haven't responded yet (tensorflow/probability#1814).

@slinderman
Copy link
Collaborator

We can track the test failure here: #367

@slinderman slinderman merged commit 9b3fb2f into probml:main Jun 26, 2024
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants