tfp.math.scan_associative doesn't work for all associative functions (it should be using vmap for lowered_fn) #1812

Open
Joshuaalbert opened this issue Jun 4, 2024 · 0 comments

Here is a simple example of an associative function that scan_associative fails to handle because it assumes the associative op broadcasts.

The solution is to use jax.vmap to distribute the elements in lowered_fn rather than relying on broadcasting.
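As a rough illustration of the vmap idea, the sketch below wraps an op written for single elements in jax.vmap before handing it to a scan. It uses jax.lax.associative_scan as a stand-in, which (like scan_associative) calls the op on batched slices. It does not show TFP's lowered_fn, and the names per_pair_op and batched_op are hypothetical.

```python
import jax
import jax.numpy as jnp


def per_pair_op(x, y):
    # Hypothetical op written for single, unbatched elements. jnp.sum
    # collapses every axis, so calling this directly on batched slices
    # would mix elements together instead of broadcasting.
    return jnp.sum(x) + jnp.sum(y)


elems = jnp.array([1.0, 2.0, 3.0])

# jax.vmap maps the op over the leading axis, so each pair of elements is
# combined independently even though the op itself does not broadcast.
batched_op = jax.vmap(per_pair_op)

# jax.lax.associative_scan passes batched slices of elems to the op.
result = jax.lax.associative_scan(batched_op, elems)

# For this op the scan should agree with a running sum.
expected = jnp.cumsum(elems)
```

A caller could apply the same wrapping as a workaround, but the issue's point is that the library should do this internally so that any associative op defined on single elements is handled.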

MVCE

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp


def explicit_verify_associative(op, elems):
    output_1 = op(op(elems[0], elems[1]), elems[2])
    output_2 = op(elems[0], op(elems[1], elems[2]))
    print(output_1, output_2)
    # Floating-point addition is not exactly associative, so compare within tolerance.
    assert jnp.allclose(output_1, output_2)


def main():
    elems = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))

    elem_shape = jax.tree.map(lambda x: np.shape(x[0]), elems)  # ()

    def per_elem_op(x) -> jax.Array:
        return jnp.sum(x)

    def associative_op(x, y):
        print(f"x.shape={np.shape(x)}, y.shape={np.shape(y)}")
        assert np.shape(x) == elem_shape
        assert np.shape(y) == elem_shape
        return per_elem_op(x) + per_elem_op(y)

    explicit_verify_associative(associative_op, elems)

    _ = tfp.math.scan_associative(associative_op, elems)


if __name__ == '__main__':
    main()