AssertionError when computing gradient of a function with reduce_window #24754
Unanswered · LeonardoTredese asked this question in Q&A · 1 comment, 1 reply
-
Based on your error, you may have noticed the following:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn  # max_pool lives in Flax; jax.nn itself has no pooling op

x = jnp.ones((1, 28, 28))

def f(x):
    # flax.linen.max_pool wraps lax.reduce_window with lax.max, so it is differentiable
    return nn.max_pool(x, (2, 2), (2, 2), padding="SAME").sum()

gradient = jax.grad(f)(x)
print(gradient)
```
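If you prefer to stay with jax.lax.reduce_window directly (as in the title), here is a minimal sketch of the same max pooling written against the raw primitive. The window and stride shapes (1, 2, 2) are my assumption for an input of shape (1, 28, 28) with pooling over the two trailing axes; adjust them to your layout.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 28, 28))

def max_pool_sum(x):
    # jax.lax.max with a -inf init value is one of the reducers JAX recognizes,
    # so this lowers to the specialized max-reduction primitive, which has a
    # gradient rule.
    pooled = jax.lax.reduce_window(
        x,
        -jnp.inf,                     # identity element for max
        jax.lax.max,                  # binary scalar max
        window_dimensions=(1, 2, 2),  # no pooling over the leading axis
        window_strides=(1, 2, 2),
        padding="SAME",
    )
    return pooled.sum()

print(jax.grad(max_pool_sum)(x))
```

As far as I know, only the reducers JAX recognizes (lax.add, lax.max, lax.min, each paired with its identity element as the init value) are specialized to primitives with differentiation rules; a hand-written lambda or jnp.maximum falls back to the generic reduce_window primitive, which cannot be differentiated.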
-
When running the following example, I get the following assertion error:

I am using Python 3.11.6 on Ubuntu 23.10, with:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35
I believe this is not the expected behaviour. Am I wrong? How can I get the gradients of this function?