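Here is a simple example of an associative function that `scan_associative` fails to handle because it assumes the associative op broadcasts, i.e. that the op can be called on whole slices of elements at once rather than on single elements.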
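To see that assumption concretely, the sketch below records the shapes a parallel scan hands to its binary op. It uses `jax.lax.associative_scan` as a stand-in for `tfp.math.scan_associative`, assuming both follow the same pairwise-combine recursion; only the JAX function is exercised here. `recording_add` and `seen_shapes` are illustrative names, not part of either library.

```python
import jax
import jax.numpy as jnp

# Illustrative log of the (x.shape, y.shape) pairs the scan passes to the op.
seen_shapes = []


def recording_add(x, y):
    # Record the argument shapes, then combine with plain addition.
    seen_shapes.append((x.shape, y.shape))
    return x + y


elems = jnp.arange(7.0)
# Stand-in for tfp.math.scan_associative (assumed to batch calls the same way).
out = jax.lax.associative_scan(recording_add, elems)

# Each call receives a 1-D slice of elements, never a single scalar element.
print(seen_shapes)
print(out)
```

An op that only accepts single elements, like `associative_op` in the MVCE, therefore breaks as soon as the scan combines a slice of several elements in one call.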
The solution is to use `jax.vmap` to distribute elements in `lowered_fn` here rather than relying on broadcasting.
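A minimal sketch of the same idea applied from the caller's side: wrap a per-element op in `jax.vmap` before handing it to the scan. As above, `jax.lax.associative_scan` stands in for `tfp.math.scan_associative`; its documentation asks callers to make `fn` apply elementwise over the scanned axis, for example via `jax.vmap`, whereas this issue proposes doing that `jax.vmap` inside `lowered_fn` instead. `scalar_only_op` is a hypothetical op modeled on `associative_op` in the MVCE.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scalar_only_op(x, y):
    # Hypothetical per-element op, modeled on `associative_op` in the MVCE:
    # it rejects any argument that is not a single scalar element.
    assert np.shape(x) == () and np.shape(y) == (), (np.shape(x), np.shape(y))
    return jnp.sum(x) + jnp.sum(y)


elems = jax.random.normal(jax.random.PRNGKey(0), shape=(7,))

# Passing the op directly fails: the scan combines whole slices of `elems`
# in one call, so the scalar-shape assert trips.
try:
    jax.lax.associative_scan(scalar_only_op, elems)
    raw_failed = False
except AssertionError:
    raw_failed = True

# jax.vmap maps the op over the leading axis of both slices, so inside
# `scalar_only_op` the arguments have the per-element shape ().
vmapped_result = jax.lax.associative_scan(jax.vmap(scalar_only_op), elems)

print(raw_failed)  # True
print(bool(jnp.allclose(vmapped_result, jnp.cumsum(elems), atol=1e-5)))  # True
```

Because the scalar op here is just addition, the vmapped scan reproduces a cumulative sum, which is why `jnp.cumsum` serves as the reference.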
MVCE:

```python
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp


def explicit_verify_associative(op, elems):
    # Check op(op(a, b), c) against op(a, op(b, c)) on single elements.
    output_1 = op(op(elems[0], elems[1]), elems[2])
    output_2 = op(elems[0], op(elems[1], elems[2]))
    print(output_1, output_2)
    # Float addition is only approximately associative, so compare with a tolerance.
    np.testing.assert_allclose(output_1, output_2, rtol=1e-6)


def main():
    elems = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))
    elem_shape = jax.tree.map(lambda x: np.shape(x[0]), elems)  # ()

    def per_elem_op(x) -> jax.Array:
        return jnp.sum(x)

    def associative_op(x, y):
        # Only accepts single elements; a batched slice trips these asserts.
        print(f"x.shape={np.shape(x)}, y.shape={np.shape(y)}")
        assert np.shape(x) == elem_shape
        assert np.shape(y) == elem_shape
        return per_elem_op(x) + per_elem_op(y)

    explicit_verify_associative(associative_op, elems)  # passes
    # Fails: scan_associative calls the op on batched slices, not single elements.
    _ = tfp.math.scan_associative(associative_op, elems)


if __name__ == '__main__':
    main()
```