Hi,
I implemented a softmax-weighted pooling function, which scales each 2x2 patch in the image by the softmax of that patch. It is essentially a smooth version of max pooling.
It is quite slow compared to max pooling, and I am wondering whether it can be sped up.
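For intuition, the "smooth max" behavior can be checked numerically. This sketch adds an illustrative temperature parameter `t` (not part of the function below): weighting a patch by a temperature-scaled softmax interpolates between mean pooling (t → 0) and max pooling (t → ∞).

```python
import jax
import jax.numpy as jnp

patch = jnp.array([1.0, 2.0, 3.0, 4.0])

def soft_pool(x, t):
    # Softmax-weighted sum of the patch; t is an illustrative temperature.
    # t = 0 gives uniform weights (mean); large t concentrates all weight
    # on the largest element (max).
    return jnp.sum(jax.nn.softmax(t * x) * x)

# soft_pool(patch, 0.0)  -> mean of the patch (2.5)
# soft_pool(patch, 50.0) -> close to the max (4.0)
```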
1. Is `conv_general_dilated_patches` the right approach here?
2. Can `conv_general_dilated_patches` place the patch elements in a new dimension rather than stacking them into the channel dimension? That would let me avoid reshaping the array.
3. Is there a faster alternative to softmax (for example a fastmath version, or a cheaper function that could be used for this kind of gating)?
```python
import jax
import jax.numpy as jnp

@jax.jit
def sfm_pool(inputs):
    # Extract non-overlapping 2x2 patches; the 4 patch elements are stacked
    # into the channel dimension of the output.
    x = jax.lax.conv_general_dilated_patches(
        inputs, (2, 2), (2, 2), "VALID",
        dimension_numbers=("NHWC", "OIHW", "NHWC"))
    n, h, w, c = x.shape
    # Split the stacked channels back into (channels, patch elements).
    x = x.reshape(n, h, w, c // 4, 4)
    # Weight each element of a patch by its softmax over that patch.
    weights = jax.nn.softmax(x, axis=4)
    y = jnp.sum(weights * x, axis=4)
    return y
```
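One possible alternative, sketched below under the assumption of NHWC input with even spatial dimensions: since the 2x2 patches are non-overlapping, they can be exposed with a plain reshape/transpose instead of `conv_general_dilated_patches`, which avoids the channel-stacking layout entirely. The function name `sfm_pool_reshape` is my own; whether it is actually faster should be benchmarked.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sfm_pool_reshape(inputs):
    # Assumes NHWC input with even height and width.
    n, h, w, c = inputs.shape
    # Split each spatial axis into (blocks, within-block) pairs:
    # (n, h//2, 2, w//2, 2, c)
    x = inputs.reshape(n, h // 2, 2, w // 2, 2, c)
    # Gather the two within-block axes at the end and flatten them
    # into a single patch axis of size 4: (n, h//2, w//2, c, 4)
    x = x.transpose(0, 1, 3, 5, 2, 4).reshape(n, h // 2, w // 2, c, 4)
    # Softmax-weighted sum over each patch, as in the original function.
    weights = jax.nn.softmax(x, axis=-1)
    return jnp.sum(weights * x, axis=-1)
```

Because the softmax-weighted sum is invariant to the ordering of elements within a patch, this should produce the same values as the patches-based version.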