Skip to content

Commit

Permalink
fix: Fix frontend torch.Tensor.unfold method to output correct dimens…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
hmahmood24 committed Sep 14, 2024
1 parent da92c0f commit 9321152
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,24 +795,35 @@ def new_empty(
return torch_frontend.tensor(_data)

def unfold(self, dimension, size, step):
# Ensure the dimension size is large enough for unfolding
if self.shape[dimension] < size:
raise ValueError(
f"Dimension size ({self.shape[dimension]}) is "
"smaller than the unfolding size ({size})."
)

slices = []
self_shape = tuple(self.shape)

# Create sliding window slices
for i in range(0, self_shape[dimension] - size + 1, step):
slicing = [slice(None)] * len(self.shape)
slicing[dimension] = slice(i, i + size)
slices.append(self.ivy_array[tuple(slicing)])
stacked = torch_frontend.stack(slices, dim=dimension)

# Stack the slices along a new dimension at 'dimension + 1'
stacked = torch_frontend.stack(slices, dim=dimension + 1)

# Reshape the tensor to insert a new window dimension
new_shape = list(self.shape)
num_slices = (self.shape[dimension] - size) // step + 1

# Replace size of the unfolded dimension with the number of slices
new_shape[dimension] = num_slices
if dimension == -1:
new_shape.insert(dimension, size)
else:
new_shape.insert(dimension + 1, size)
reshaped = stacked.reshape(new_shape)
dims = list(range(len(stacked.shape)))
dims[-2], dims[-1] = dims[-1], dims[-2]
return reshaped.permute(*dims)

# Append the window size at the end (correct behavior)
new_shape.append(size)
return stacked.reshape(new_shape)

def long(self, memory_format=None):
self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)
Expand Down

0 comments on commit 9321152

Please sign in to comment.