From 7f8e95dbf3d789c9f1c1d26d953a0413df387676 Mon Sep 17 00:00:00 2001 From: Mohammed Ayman Date: Fri, 18 Aug 2023 12:49:56 +0300 Subject: [PATCH] Refactor torch frontend max pooling functions (#21814) Co-authored-by: @AnnaTz --- .../torch/nn/functional/pooling_functions.py | 49 ++++++++----------- .../test_functional/test_pooling_functions.py | 28 ++--------- 2 files changed, 25 insertions(+), 52 deletions(-) diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index 201b3016ac66c..59ef0d14e9387 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -95,27 +95,26 @@ def avg_pool2d( @to_ivy_arrays_and_back -def max_pool1d(input, kernel_size, stride=None, padding=0): - kernel_size = _broadcast_pooling_helper(kernel_size, "1d", name="kernel_size") - stride = _broadcast_pooling_helper(stride, "1d", name="stride") - padding = _broadcast_pooling_helper(padding, "1d", name="padding") - kernel_pads = list(zip(kernel_size, padding)) - +def max_pool1d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + dilation=1, + return_indices=False, +): + if stride is None: + stride = kernel_size data_format = "NCW" - - if not all([pad <= kernel / 2 for kernel, pad in kernel_pads]): - raise ValueError( - "pad should be smaller than or equal to half of kernel size, " - f"but got padding={padding}, kernel_size={kernel_size}. " - ) - # figure out whether to apply padding - if all([pad == ivy.ceil((kernel - 1) / 2) for kernel, pad in kernel_pads]): - padding_str = "SAME" - else: - padding_str = "VALID" - return ivy.max_pool1d( - input, kernel_size, stride, padding_str, data_format=data_format + input, + kernel_size, + stride, + padding, + data_format=data_format, + dilation=dilation, + ceil_mode=ceil_mode, ) @@ -130,14 +129,9 @@ def max_pool2d( ceil_mode=False, return_indices=False, ): - # ToDo: Add return_indices once superset in implemented - dim_check = False - if input.ndim == 3: - input = input.expand_dims() - dim_check = True - if not stride: + if stride is None: stride = kernel_size - ret = ivy.max_pool2d( + return ivy.max_pool2d( input, kernel_size, stride, @@ -146,9 +140,6 @@ def max_pool2d( dilation=dilation, ceil_mode=ceil_mode, ) - if dim_check: - return ret.squeeze(0) - return ret @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py index ce37a6be09e0b..2081a66eab98b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py @@ -201,8 +201,8 @@ def test_torch_avg_pool3d( max_dims=3, min_side=1, max_side=3, - explicit_or_str_padding=False, only_explicit_padding=True, + data_format="channel_first", ), test_with_out=st.just(False), ) @@ -216,23 +216,7 @@ def test_torch_max_pool1d( on_device, ): input_dtype, x, kernel_size, stride, padding = dtype_x_k_s - - # Torch ground truth func expects input to be consistent - # with a channels first format i.e. NCW - x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], x[0].shape[1])) - x_shape = [x[0].shape[2]] - - # Torch ground truth func also takes padding input as an integer - # or a tuple of integers, not a string - padding = tuple( - [ - ivy.functional.layers._handle_padding( - x_shape[i], stride[0], kernel_size[i], padding - ) - for i in range(len(x_shape)) - ] - ) - + padding = (padding[0][0],) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -255,9 +239,10 @@ def test_torch_max_pool1d( max_dims=4, min_side=1, max_side=4, - explicit_or_str_padding=True, + only_explicit_padding=True, return_dilation=True, - ).filter(lambda x: x[4] != "VALID" and x[4] != "SAME"), + data_format="channel_first", + ), test_with_out=st.just(False), ceil_mode=st.just(True), ) @@ -272,9 +257,6 @@ def test_torch_max_pool2d( on_device, ): dtype, x, kernel, stride, pad, dilation = x_k_s_p - # Torch ground truth func expects input to be consistent - # with a channels first format i.e. NCHW - x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1])) pad = (pad[0][0], pad[1][0]) helpers.test_frontend_function(