Skip to content

Commit

Permalink
feat: added Ivy.unflatten (#28079)
Browse files Browse the repository at this point in the history
Co-authored-by: joaozenobio <[email protected]>
  • Loading branch information
Kacper-W-Kozdon and joaozenobio authored Feb 3, 2024
1 parent e786e88 commit 323b33d
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 18 deletions.
56 changes: 56 additions & 0 deletions ivy/data_classes/array/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,62 @@ def take(
self, indices, axis=axis, mode=mode, fill_value=fill_value, out=out
)

def unflatten(
self: ivy.Array,
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray],
dim: Optional[int] = 0,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.
Parameters
----------
self
input array
shape
array indices. Must have an integer data type.
dim
axis over which to unflatten. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `dim`
which is replaced with a tuple specified in `shape`.
Examples
--------
With 'ivy.Array' input:
>>> x = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]])
>>> shape = (2, 2)
>>> y = x.unflatten(shape=shape, dim=dim, out=y)
>>> print(y)
ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
"""
return ivy.unflatten(
self._data,
shape=shape,
dim=dim,
out=out,
)

def trim_zeros(
self: ivy.Array,
/,
Expand Down
182 changes: 178 additions & 4 deletions ivy/data_classes/container/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4065,6 +4065,180 @@ def trim_zeros(
"""
return self._static_trim_zeros(self, trim=trim)

@staticmethod
def _static_unflatten(
x: Union[int, ivy.Array, ivy.NativeArray, ivy.Container],
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container],
dim: Optional[Union[int, ivy.Container]] = 0,
*,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
"""ivy.Container static method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.
Parameters
----------
x
input array
shape
array indices. Must have an integer data type.
dim
axis over which to select values. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
key_chains
The key-chains to apply or not apply the method to.
Default is ``None``.
to_apply
If True, the method will be applied to key_chains,
otherwise key_chains will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was
not applied. Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `axis`
whose size must equal the number of elements in `indices`.
Examples
--------
With 'ivy.Container' input:
>>> x = ivy.Container(a = ivy.array([[True, False, False, True],
[False, True, False, True]])),
... b = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]]),
... c = ivy.array([[1, 2, 3, 4],
[5, 6, 7, 8]]))
>>> dim = 1
>>> shape = (2, 2)
>>> y = ivy.Container._static_unflatten(x, shape=shape, dim=dim)
>>> print(y)
{
a: ivy.array([[[True, False], [False, True]],
[[False, True], [False, True]]])
b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
}
"""
return ContainerBase.cont_multi_map_in_function(
"unflatten",
x,
shape=shape,
dim=dim,
out=out,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def unflatten(
self: ivy.Container,
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container],
dim: Optional[Union[int, ivy.Container]] = 0,
*,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.
Parameters
----------
self
input array
shape
array indices. Must have an integer data type.
dim
axis over which to unflatten. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
key_chains
The key-chains to apply or not apply the method to.
Default is ``None``.
to_apply
If True, the method will be applied to key_chains,
otherwise key_chains will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was
not applied. Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `dim`
which is replaced with a tuple specified in `shape`.
Examples
--------
With 'ivy.Container' input:
>>> x = ivy.Container(a = ivy.array([[True, False, False, True],
[False, True, False, True]])),
... b = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]]),
... c = ivy.array([[1, 2, 3, 4],
[5, 6, 7, 8]]))
>>> dim = 1
>>> shape = (2, 2)
>>> y = x.unflatten(shape=shape, dim=dim)
>>> print(y)
{
a: ivy.array([[[True, False], [False, True]],
[[False, True], [False, True]]])
b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
}
"""
return self._static_unflatten(
self,
shape=shape,
dim=dim,
out=out,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)


def concat_from_sequence(
self: ivy.Container,
Expand Down Expand Up @@ -4130,11 +4304,11 @@ def concat_from_sequence(
>>> print(z)
{
'a': ivy.array([[[0, 1],
[3, 2]],
[[2, 3],
[1, 0]]]),
[3, 2]],
[[2, 3],
[1, 0]]]),
'b': ivy.array([[[4, 5],
[1, 0]]])
[1, 0]]])
}
"""
new_input_sequence = (
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,8 @@ def trim_zeros(a: JaxArray, /, *, trim: Optional[str] = "bf") -> JaxArray:
def unflatten(
x: JaxArray,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: int = 0,
*,
out: Optional[JaxArray] = None,
order: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,8 @@ def put_along_axis(
def unflatten(
x: np.ndarray,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[np.ndarray] = None,
order: Optional[str] = None,
Expand Down
16 changes: 15 additions & 1 deletion ivy/functional/backends/paddle/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,12 +908,26 @@ def put_along_axis(
]


@with_supported_dtypes(
{
"2.6.0 and below": (
"int32",
"int64",
"float64",
"complex128",
"float32",
"complex64",
"bool",
)
},
backend_version,
)
@handle_out_argument
def unflatten(
x: paddle.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: int = 0,
*,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ def trim_zeros(a: tf.Tensor, /, *, trim: Optional[str] = "bf") -> tf.Tensor:
def unflatten(
x: tf.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[tf.Tensor] = None,
name: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ def trim_zeros(a: torch.Tensor, /, *, trim: Optional[str] = "bf") -> torch.Tenso
def unflatten(
x: torch.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/ivy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2884,9 +2884,9 @@ def trim_zeros(
def unflatten(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
dim: int,
shape: Tuple[int],
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Expand a dimension of the input tensor over multiple dimensions.
Expand Down Expand Up @@ -2930,4 +2930,4 @@ def unflatten(
>>> ivy.unflatten(torch.randn(5, 12, 3), dim=-2, shape=(2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
"""
return current_backend(x).unflatten(x, dim=dim, shape=shape, out=out)
return ivy.current_backend(x).unflatten(x, dim=dim, shape=shape, out=out)
Original file line number Diff line number Diff line change
Expand Up @@ -1788,8 +1788,8 @@ def test_torch_triu_indices(
),
get_axis=helpers.get_axis(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
max_size=1,
min_size=1,
max_size=0,
min_size=0,
force_int=True,
),
)
Expand All @@ -1804,10 +1804,9 @@ def test_torch_unflatten(
shape,
get_axis,
):
if type(get_axis) is not tuple:
axis = get_axis
else:
axis = 0 if get_axis is None else get_axis[0]
axis = get_axis
if type(axis) is tuple:
axis = 0 if not get_axis else get_axis[0]
dtype, x = dtype_and_values

def factorization(n):
Expand Down Expand Up @@ -1835,7 +1834,8 @@ def get_factor(n):
next = get_factor(n)
factors.append(next)
n //= next

if len(factors) > 1:
factors.remove(1)
return factors

shape_ = (
Expand Down
Loading

0 comments on commit 323b33d

Please sign in to comment.