diff --git a/ivy/data_classes/array/experimental/manipulation.py b/ivy/data_classes/array/experimental/manipulation.py index 013af82658b0a..6cdfeb97689f0 100644 --- a/ivy/data_classes/array/experimental/manipulation.py +++ b/ivy/data_classes/array/experimental/manipulation.py @@ -1134,6 +1134,62 @@ def take( self, indices, axis=axis, mode=mode, fill_value=fill_value, out=out ) + def unflatten( + self: ivy.Array, + /, + shape: Union[Tuple[int], ivy.Array, ivy.NativeArray], + dim: Optional[int] = 0, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.unflatten. This method + simply wraps the function, and so the docstring for ivy.unflatten also + applies to this method with minimal changes. + + Parameters + ---------- + self + input array + shape + array indices. Must have an integer data type. + dim + axis over which to unflatten. If `axis` is negative, + the function must determine the axis along which to select values + by counting from the last dimension. + By default, the flattened input array is used. + out + optional output array, for writing the result to. It must + have a shape that the inputs broadcast to. + + Returns + ------- + ret + an array having the same data type as `x`. + The output array must have the same rank + (i.e., number of dimensions) as `x` and + must have the same shape as `x`, + except for the axis specified by `dim` + which is replaced with a tuple specified in `shape`. + + + Examples + -------- + With 'ivy.Array' input: + + >>> x = ivy.array([[1.2, 2.3, 3.4, 4.5], + [5.6, 6.7, 7.8, 8.9]]) + >>> shape = (2, 2) + >>> y = x.unflatten(shape=shape, dim=dim, out=y) + >>> print(y) + ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]]) + """ + return ivy.unflatten( + self._data, + shape=shape, + dim=dim, + out=out, + ) + def trim_zeros( self: ivy.Array, /, diff --git a/ivy/data_classes/container/experimental/manipulation.py b/ivy/data_classes/container/experimental/manipulation.py index 20e72671746e1..fb293ce7ff254 100644 --- a/ivy/data_classes/container/experimental/manipulation.py +++ b/ivy/data_classes/container/experimental/manipulation.py @@ -4065,6 +4065,180 @@ def trim_zeros( """ return self._static_trim_zeros(self, trim=trim) + @staticmethod + def _static_unflatten( + x: Union[int, ivy.Array, ivy.NativeArray, ivy.Container], + /, + shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container], + dim: Optional[Union[int, ivy.Container]] = 0, + *, + out: Optional[Union[ivy.Array, ivy.Container]] = None, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + """ivy.Container static method variant of ivy.unflatten. This method + simply wraps the function, and so the docstring for ivy.unflatten also + applies to this method with minimal changes. + + Parameters + ---------- + x + input array + shape + array indices. Must have an integer data type. + dim + axis over which to select values. If `axis` is negative, + the function must determine the axis along which to select values + by counting from the last dimension. + By default, the flattened input array is used. + out + optional output array, for writing the result to. It must + have a shape that the inputs broadcast to. + key_chains + The key-chains to apply or not apply the method to. + Default is ``None``. + to_apply + If True, the method will be applied to key_chains, + otherwise key_chains will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was + not applied. Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + + Returns + ------- + ret + an array having the same data type as `x`. + The output array must have the same rank + (i.e., number of dimensions) as `x` and + must have the same shape as `x`, + except for the axis specified by `axis` + whose size must equal the number of elements in `indices`. + + + Examples + -------- + With 'ivy.Container' input: + + >>> x = ivy.Container(a = ivy.array([[True, False, False, True], + [False, True, False, True]])), + ... b = ivy.array([[1.2, 2.3, 3.4, 4.5], + [5.6, 6.7, 7.8, 8.9]]), + ... c = ivy.array([[1, 2, 3, 4], + [5, 6, 7, 8]])) + >>> dim = 1 + >>> shape = (2, 2) + >>> y = ivy.Container._static_unflatten(x, shape=shape, dim=dim) + >>> print(y) + { + a: ivy.array([[[True, False], [False, True]], + [[False, True], [False, True]]]) + b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]]) + c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + } + """ + return ContainerBase.cont_multi_map_in_function( + "unflatten", + x, + shape=shape, + dim=dim, + out=out, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + + def unflatten( + self: ivy.Container, + /, + shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container], + dim: Optional[Union[int, ivy.Container]] = 0, + *, + out: Optional[Union[ivy.Array, ivy.Container]] = None, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + """ivy.Container instance method variant of ivy.unflatten. This method + simply wraps the function, and so the docstring for ivy.unflatten also + applies to this method with minimal changes. + + Parameters + ---------- + self + input array + shape + array indices. Must have an integer data type. + dim + axis over which to unflatten. If `axis` is negative, + the function must determine the axis along which to select values + by counting from the last dimension. + By default, the flattened input array is used. + out + optional output array, for writing the result to. It must + have a shape that the inputs broadcast to. + key_chains + The key-chains to apply or not apply the method to. + Default is ``None``. + to_apply + If True, the method will be applied to key_chains, + otherwise key_chains will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was + not applied. Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + + Returns + ------- + ret + an array having the same data type as `x`. + The output array must have the same rank + (i.e., number of dimensions) as `x` and + must have the same shape as `x`, + except for the axis specified by `dim` + which is replaced with a tuple specified in `shape`. + + + Examples + -------- + With 'ivy.Container' input: + + >>> x = ivy.Container(a = ivy.array([[True, False, False, True], + [False, True, False, True]])), + ... b = ivy.array([[1.2, 2.3, 3.4, 4.5], + [5.6, 6.7, 7.8, 8.9]]), + ... c = ivy.array([[1, 2, 3, 4], + [5, 6, 7, 8]])) + >>> dim = 1 + >>> shape = (2, 2) + >>> y = x.unflatten(shape=shape, dim=dim) + >>> print(y) + { + a: ivy.array([[[True, False], [False, True]], + [[False, True], [False, True]]]) + b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]]) + c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + } + """ + return self._static_unflatten( + self, + shape=shape, + dim=dim, + out=out, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + def concat_from_sequence( self: ivy.Container, @@ -4130,11 +4304,11 @@ def concat_from_sequence( >>> print(z) { 'a': ivy.array([[[0, 1], - [3, 2]], - [[2, 3], - [1, 0]]]), + [3, 2]], + [[2, 3], + [1, 0]]]), 'b': ivy.array([[[4, 5], - [1, 0]]]) + [1, 0]]]) } """ new_input_sequence = ( diff --git a/ivy/functional/backends/jax/experimental/manipulation.py b/ivy/functional/backends/jax/experimental/manipulation.py index 36759c9dd1dfd..59ee341ab7870 100644 --- a/ivy/functional/backends/jax/experimental/manipulation.py +++ b/ivy/functional/backends/jax/experimental/manipulation.py @@ -475,8 +475,8 @@ def trim_zeros(a: JaxArray, /, *, trim: Optional[str] = "bf") -> JaxArray: def unflatten( x: JaxArray, /, - dim: int = 0, shape: Tuple[int] = None, + dim: int = 0, *, out: Optional[JaxArray] = None, order: Optional[str] = None, diff --git a/ivy/functional/backends/numpy/experimental/manipulation.py b/ivy/functional/backends/numpy/experimental/manipulation.py index a4a7c9ed12756..17b57a75f0f33 100644 --- a/ivy/functional/backends/numpy/experimental/manipulation.py +++ b/ivy/functional/backends/numpy/experimental/manipulation.py @@ -607,8 +607,8 @@ def put_along_axis( def unflatten( x: np.ndarray, /, - dim: int = 0, shape: Tuple[int] = None, + dim: Optional[int] = 0, *, out: Optional[np.ndarray] = None, order: Optional[str] = None, diff --git a/ivy/functional/backends/paddle/experimental/manipulation.py b/ivy/functional/backends/paddle/experimental/manipulation.py index 788a73d263c76..c6a482ee60b2d 100644 --- a/ivy/functional/backends/paddle/experimental/manipulation.py +++ b/ivy/functional/backends/paddle/experimental/manipulation.py @@ -908,12 +908,26 @@ def put_along_axis( ] +@with_supported_dtypes( + { + "2.6.0 and below": ( + "int32", + "int64", + "float64", + "complex128", + "float32", + "complex64", + "bool", + ) + }, + backend_version, +) @handle_out_argument def unflatten( x: paddle.Tensor, /, - dim: int = 0, shape: Tuple[int] = None, + dim: int = 0, *, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: diff --git a/ivy/functional/backends/tensorflow/experimental/manipulation.py b/ivy/functional/backends/tensorflow/experimental/manipulation.py index e03aebe12f749..e188f00a8aaae 100644 --- a/ivy/functional/backends/tensorflow/experimental/manipulation.py +++ b/ivy/functional/backends/tensorflow/experimental/manipulation.py @@ -571,8 +571,8 @@ def trim_zeros(a: tf.Tensor, /, *, trim: Optional[str] = "bf") -> tf.Tensor: def unflatten( x: tf.Tensor, /, - dim: int = 0, shape: Tuple[int] = None, + dim: Optional[int] = 0, *, out: Optional[tf.Tensor] = None, name: Optional[str] = None, diff --git a/ivy/functional/backends/torch/experimental/manipulation.py b/ivy/functional/backends/torch/experimental/manipulation.py index 485884c6511ed..97dd7183fa442 100644 --- a/ivy/functional/backends/torch/experimental/manipulation.py +++ b/ivy/functional/backends/torch/experimental/manipulation.py @@ -649,8 +649,8 @@ def trim_zeros(a: torch.Tensor, /, *, trim: Optional[str] = "bf") -> torch.Tenso def unflatten( x: torch.Tensor, /, - dim: int = 0, shape: Tuple[int] = None, + dim: Optional[int] = 0, *, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/ivy/functional/ivy/experimental/manipulation.py b/ivy/functional/ivy/experimental/manipulation.py index a4bf5fd435a47..d46b2bed23bc7 100644 --- a/ivy/functional/ivy/experimental/manipulation.py +++ b/ivy/functional/ivy/experimental/manipulation.py @@ -2884,9 +2884,9 @@ def trim_zeros( def unflatten( x: Union[ivy.Array, ivy.NativeArray], /, - *, dim: int, shape: Tuple[int], + *, out: Optional[ivy.Array] = None, ) -> ivy.Array: """Expand a dimension of the input tensor over multiple dimensions. @@ -2930,4 +2930,4 @@ def unflatten( >>> ivy.unflatten(torch.randn(5, 12, 3), dim=-2, shape=(2, 2, 3, 1, 1)).shape torch.Size([5, 2, 2, 3, 1, 1, 3]) """ - return current_backend(x).unflatten(x, dim=dim, shape=shape, out=out) + return ivy.current_backend(x).unflatten(x, dim=dim, shape=shape, out=out) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py index 5a77b6cfbbe6e..5b15e0ec81308 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py @@ -1788,8 +1788,8 @@ def test_torch_triu_indices( ), get_axis=helpers.get_axis( shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - max_size=1, - min_size=1, + max_size=0, + min_size=0, force_int=True, ), ) @@ -1804,10 +1804,9 @@ def test_torch_unflatten( shape, get_axis, ): - if type(get_axis) is not tuple: - axis = get_axis - else: - axis = 0 if get_axis is None else get_axis[0] + axis = get_axis + if type(axis) is tuple: + axis = 0 if not get_axis else get_axis[0] dtype, x = dtype_and_values def factorization(n): @@ -1835,7 +1834,8 @@ def get_factor(n): next = get_factor(n) factors.append(next) n //= next - + if len(factors) > 1: + factors.remove(1) return factors shape_ = ( diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py index 87b09dccecfd6..d88849c914204 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py @@ -2,6 +2,7 @@ from hypothesis import strategies as st, assume import hypothesis.extra.numpy as nph import numpy as np +import math # local import ivy @@ -1470,6 +1471,85 @@ def test_trim_zeros( ) +# unflatten +@handle_test( + fn_tree="functional.ivy.experimental.unflatten", + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + shape_key="shape", + ), + get_axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + max_size=0, + min_size=0, + force_int=True, + ), +) +def test_unflatten( + *, + dtype_and_values, + on_device, + fn_name, + test_flags, + backend_fw, + shape, + get_axis, +): + axis = get_axis + if type(axis) is tuple: + axis = 0 if not get_axis else get_axis[0] + dtype, x = dtype_and_values + + def factorization(n): + factors = [1] + + def get_factor(n): + x_fixed = 2 + cycle_size = 2 + x = 2 + factor = 1 if n % 2 else 2 + + while factor == 1: + for count in range(cycle_size): + if factor > 1: + break + x = (x * x + 1) % n + factor = math.gcd(x - x_fixed, n) + + cycle_size *= 2 + x_fixed = x + + return factor + + while n > 1: + next = get_factor(n) + factors.append(next) + n //= next + + if len(factors) > 1: + factors.remove(1) + return factors + + shape_ = ( + tuple(factorization(shape[axis])) + if tuple(factorization(shape[axis])) + else shape + ) + helpers.test_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + test_values=False, + x=x[0], + shape=shape_, + dim=axis, + ) + + @handle_test( fn_tree="functional.ivy.experimental.unfold", dtype_values_axis=helpers.dtype_values_axis(