-
Trying to wrangle something I thought would work with `pmap`. Basically, I have a function that takes an input, downsamples it by some factor, runs some convolutions, and then does spatial mean pooling. For the purpose of this example, though, we just do the downsampling and pooling (see this simple repro colab for the full code):

```python
import haiku as hk
import jax.numpy as jnp


def downsample(x, factor):
    # note: I get the same problems with jax.image.resize
    return hk.max_pool(x, window_shape=(1, factor, factor, 1),
                       strides=(1, factor, factor, 1), padding='VALID')


def global_spatial_mean_pooling(x):
    return jnp.mean(x, axis=(1, 2))


def downsample_and_pool(x, factor):
    x = downsample(x, factor)
    # in the real problem, do some convolutions here
    x = global_spatial_mean_pooling(x)
    return x


x = jnp.ones((3, 8, 8, 5))
print(downsample_and_pool(x, 2).shape)  # internally downsamples to (3, 4, 4, 5)
print(downsample_and_pool(x, 3).shape)  # internally downsamples to (3, 2, 2, 5)
```
```
(3, 5)
(3, 5)
```

I thought I'd be able to do this:

```python
from jax import pmap

factors = jnp.array([1, 2])
pmap(downsample_and_pool, in_axes=(None, 0))(x, factors)
```
but this gives:

```
FilteredStackTrace: TypeError: reduce_window window_dimensions must have every element be an
integer type, got (<class 'int'>, <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>,
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>, <class 'int'>).
```

I only want to run with a small number of distinct values for `factor`, so I tried:

```python
from jax import jit

downsample = jit(downsample, static_argnums=1)
```

which fails with:

```
...
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected
cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
for function downsample is non-hashable.
```

For the downsampling I don't need anything fancy, so I'm wondering how far I'll get if I dig deeper in an attempt to make a super-minimal version of the downsampling that avoids something. Or is the inconsistent shape during the computation just not going to work, and I'm fighting a losing battle?

Cheers,
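To make the shape problem concrete, here is a small sketch that reports the intermediate (downsampled) shape for a few factors using `jax.eval_shape`, which traces a function abstractly without running it. It assumes that `jax.lax.reduce_window` (the primitive named in the `TypeError` above) with a max reducer and `padding="VALID"` is a fair stand-in for `hk.max_pool`; that substitution is only there so the sketch runs without Haiku.

```python
import jax
import jax.numpy as jnp
from jax import lax


def downsample(x, factor):
    # Max-pool with a square window and stride equal to `factor` (NHWC layout),
    # standing in for hk.max_pool(..., padding='VALID') from the question.
    return lax.reduce_window(
        x,
        -jnp.inf,  # identity element for max
        lax.max,
        window_dimensions=(1, factor, factor, 1),
        window_strides=(1, factor, factor, 1),
        padding="VALID",
    )


x = jnp.ones((3, 8, 8, 5))

# jax.eval_shape only traces the function and returns a ShapeDtypeStruct.
# Here `factor` is a plain Python int closed over by the lambda, so it is static.
intermediate_shapes = {
    factor: jax.eval_shape(lambda v: downsample(v, factor), x).shape
    for factor in (1, 2, 3)
}
print(intermediate_shapes)
# {1: (3, 8, 8, 5), 2: (3, 4, 4, 5), 3: (3, 2, 2, 5)}
```

Because the intermediate shape is a function of `factor`, the factor has to be a concrete Python value when JAX traces the function; a traced array element cannot determine an array shape.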
-
Marking the `factor` argument to `downsample` as static seems like the correct intuition, but by that point the outer `pmap`'ed function is passing it something abstract (the `DynamicJaxprTracer` in the error). At the end of the day, we cannot create a constant/static downsampling `factor` (different values of which beget a differently compiled `reduce_window` operation) from a dynamic argument. But we also cannot `pmap` over several downsampling factors unless we make the factor a dynamic argument!

In other words: this example isn't an instance of "single program, multiple data" parallelism (which the pmap/vmap parallelism model requires), because different `factor`s make different programs.
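One way to work within this constraint, sketched here under the assumption that the set of factors is small and known ahead of time: make `factor` static on the *outermost* jitted function and loop over the factors in Python, so each value gets its own trace and compiled program. Since the pooled outputs all have the same shape, they can be stacked afterwards. `jax.lax.reduce_window` stands in for `hk.max_pool` here (an assumption, so the sketch runs without Haiku).

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def downsample(x, factor):
    # Max-pool downsampling (NHWC), standing in for hk.max_pool with 'VALID' padding.
    return lax.reduce_window(
        x, -jnp.inf, lax.max,
        window_dimensions=(1, factor, factor, 1),
        window_strides=(1, factor, factor, 1),
        padding="VALID",
    )


@functools.partial(jax.jit, static_argnums=1)
def downsample_and_pool(x, factor):
    # `factor` is static: each distinct value triggers its own trace and
    # compilation, i.e. its own "program".
    x = downsample(x, factor)
    # (convolutions would go here in the real model)
    return jnp.mean(x, axis=(1, 2))


x = jnp.ones((3, 8, 8, 5))
factors = (1, 2, 3)  # plain, hashable Python ints, so they are valid static args

# One compiled program per factor, dispatched sequentially from Python.
# Every result has shape (3, 5), so they can be stacked along a new axis.
results = jnp.stack([downsample_and_pool(x, f) for f in factors])
print(results.shape)  # (3, 3, 5)
```

This gives up parallelism across factors in exchange for one compilation per distinct factor, which stays manageable when there are only a few distinct values.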