Skip to content

Commit

Permalink
fix (torch frontends)(pooling_functions.py): fixing the implementatio…
Browse files Browse the repository at this point in the history
…n of `torch.nn.functional.max_pool2d` to handle scalar kernel_size and dilation and convert it to a tuple.
  • Loading branch information
YushaArif99 committed Oct 10, 2024
1 parent ece779c commit 4ea15e9
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def max_pool2d(

DIMS = 2
x_shape = list(input.shape[2:])
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(dilation, int):
dilation = (dilation, dilation)
new_kernel = [
kernel_size[i] + (kernel_size[i] - 1) * (dilation[i] - 1)
for i in range(DIMS)
Expand Down

0 comments on commit 4ea15e9

Please sign in to comment.