Skip to content

Commit

Permalink
Refactor torch frontend max pooling functions (#21814)
Browse files Browse the repository at this point in the history
Co-authored-by: @AnnaTz
  • Loading branch information
mohame54 authored Aug 18, 2023
1 parent 20870d9 commit 7f8e95d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 52 deletions.
49 changes: 20 additions & 29 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,26 @@ def avg_pool2d(


@to_ivy_arrays_and_back
def max_pool1d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    dilation=1,
    return_indices=False,
):
    """Apply a 1D max pooling over the input.

    Mirrors ``torch.nn.functional.max_pool1d`` by delegating to
    ``ivy.max_pool1d`` with a channels-first layout.

    Parameters
    ----------
    input
        Input array; assumed channels-first, i.e. (N, C, W) — the "NCW"
        data format passed below. TODO(review): confirm against callers.
    kernel_size
        Size of the pooling window.
    stride
        Stride of the window. Defaults to ``kernel_size`` when ``None``,
        matching torch semantics.
    padding
        Implicit negative-infinity padding added on both sides.
    ceil_mode
        When True, use ceil instead of floor to compute the output shape.
    dilation
        Spacing between window elements.
    return_indices
        Accepted for torch signature compatibility but currently ignored —
        indices are not returned. NOTE(review): add support once the ivy
        superset implements it.

    Returns
    -------
    The pooled array, as produced by ``ivy.max_pool1d``.
    """
    # Torch defaults the stride to the kernel size when it is not given.
    if stride is None:
        stride = kernel_size
    # Torch tensors are channels-first for 1D pooling: (batch, channels, width).
    data_format = "NCW"

    return ivy.max_pool1d(
        input,
        kernel_size,
        stride,
        padding,
        data_format=data_format,
        dilation=dilation,
        ceil_mode=ceil_mode,
    )


Expand All @@ -130,14 +129,9 @@ def max_pool2d(
ceil_mode=False,
return_indices=False,
):
# ToDo: Add return_indices once superset in implemented
dim_check = False
if input.ndim == 3:
input = input.expand_dims()
dim_check = True
if not stride:
if stride is None:
stride = kernel_size
ret = ivy.max_pool2d(
return ivy.max_pool2d(
input,
kernel_size,
stride,
Expand All @@ -146,9 +140,6 @@ def max_pool2d(
dilation=dilation,
ceil_mode=ceil_mode,
)
if dim_check:
return ret.squeeze(0)
return ret


@to_ivy_arrays_and_back
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def test_torch_avg_pool3d(
max_dims=3,
min_side=1,
max_side=3,
explicit_or_str_padding=False,
only_explicit_padding=True,
data_format="channel_first",
),
test_with_out=st.just(False),
)
Expand All @@ -216,23 +216,7 @@ def test_torch_max_pool1d(
on_device,
):
input_dtype, x, kernel_size, stride, padding = dtype_x_k_s

# Torch ground truth func expects input to be consistent
# with a channels first format i.e. NCW
x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], x[0].shape[1]))
x_shape = [x[0].shape[2]]

# Torch ground truth func also takes padding input as an integer
# or a tuple of integers, not a string
padding = tuple(
[
ivy.functional.layers._handle_padding(
x_shape[i], stride[0], kernel_size[i], padding
)
for i in range(len(x_shape))
]
)

padding = (padding[0][0],)
helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
Expand All @@ -255,9 +239,10 @@ def test_torch_max_pool1d(
max_dims=4,
min_side=1,
max_side=4,
explicit_or_str_padding=True,
only_explicit_padding=True,
return_dilation=True,
).filter(lambda x: x[4] != "VALID" and x[4] != "SAME"),
data_format="channel_first",
),
test_with_out=st.just(False),
ceil_mode=st.just(True),
)
Expand All @@ -272,9 +257,6 @@ def test_torch_max_pool2d(
on_device,
):
dtype, x, kernel, stride, pad, dilation = x_k_s_p
# Torch ground truth func expects input to be consistent
# with a channels first format i.e. NCHW
x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1]))
pad = (pad[0][0], pad[1][0])

helpers.test_frontend_function(
Expand Down

0 comments on commit 7f8e95d

Please sign in to comment.