Skip to content

Commit

Permalink
feat: Update set.py (#23570)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Sandoval <[email protected]>
Co-authored-by: ivy-branch <[email protected]>
  • Loading branch information
3 people authored Oct 4, 2023
1 parent a7ea663 commit 0f33aeb
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 17 deletions.
12 changes: 9 additions & 3 deletions ivy/functional/backends/jax/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,19 @@ def unique_counts(
def unique_inverse(
x: JaxArray,
/,
*,
axis: Optional[int] = None,
) -> Tuple[JaxArray, JaxArray]:
Results = namedtuple("Results", ["values", "inverse_indices"])
values, inverse_indices = jnp.unique(x, return_inverse=True)
values, inverse_indices = jnp.unique(x, return_inverse=True, axis=axis)

nan_count = jnp.count_nonzero(jnp.isnan(x))
if nan_count > 1:
values = jnp.append(values, jnp.full(nan_count - 1, jnp.nan)).astype(x.dtype)
inverse_indices = jnp.reshape(inverse_indices, x.shape)
values = jnp.append(values, jnp.full(nan_count - 1, jnp.nan), axis=0).astype(
x.dtype
)
inverse_indices = jnp.reshape(inverse_indices, x.shape, axis=0)

return Results(values, inverse_indices)


Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/mxnet/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def unique_counts(


def unique_inverse(
x: Union[(None, mx.ndarray.NDArray)], /
x: Union[(None, mx.ndarray.NDArray)], /, *, axis: Optional[int] = None
) -> Tuple[(Union[(None, mx.ndarray.NDArray)], Union[(None, mx.ndarray.NDArray)])]:
raise IvyNotImplementedException()

Expand Down
10 changes: 7 additions & 3 deletions ivy/functional/backends/numpy/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,17 @@ def unique_counts(
def unique_inverse(
x: np.ndarray,
/,
*,
axis: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
Results = namedtuple("Results", ["values", "inverse_indices"])
values, inverse_indices = np.unique(x, return_inverse=True)
values, inverse_indices = np.unique(x, return_inverse=True, axis=axis)
nan_count = np.count_nonzero(np.isnan(x))
if nan_count > 1:
values = np.append(values, np.full(nan_count - 1, np.nan)).astype(x.dtype)
inverse_indices = inverse_indices.reshape(x.shape)
values = np.append(values, np.full(nan_count - 1, np.nan), axis=axis).astype(
x.dtype
)
inverse_indices = np.reshape(inverse_indices, x.shape, axis=0)
return Results(values, inverse_indices)


Expand Down
19 changes: 17 additions & 2 deletions ivy/functional/backends/paddle/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,23 @@ def unique_counts(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version
)
def unique_inverse(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
unique, inverse_val = paddle.unique(x, return_inverse=True)
def unique_inverse(
x: paddle.Tensor,
/,
*,
axis: Optional[int] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
x, x_dtype = x.cast("float32"), x.dtype
else:
x.dtype

if axis is not None:
unique, inverse_val = paddle.unique(x, return_inverse=True, axis=axis)

if axis is None:
axis = 0

nan_idx = paddle.where(paddle.isnan(x) > 0)
nan_count = paddle.count_nonzero(nan_idx).numpy()[0]

Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/backends/tensorflow/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ def unique_counts(
def unique_inverse(
x: Union[tf.Tensor, tf.Variable],
/,
*,
axis: Optional[int] = None,
) -> Tuple[Union[tf.Tensor, tf.Variable], Union[tf.Tensor, tf.Variable]]:
Results = namedtuple("Results", ["values", "inverse_indices"])
if axis is None:
x = tf.reshape(x, shape=(-1,))
axis = 0

flat_tensor = tf.reshape(x, -1)
values = tf.unique(tf.sort(flat_tensor))[0]
values = tf.cast(values, dtype=x.dtype)
Expand Down
14 changes: 12 additions & 2 deletions ivy/functional/backends/torch/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,19 @@ def unique_counts(x: torch.Tensor, /) -> Tuple[torch.Tensor, torch.Tensor]:
},
backend_version,
)
def unique_inverse(x: torch.Tensor, /) -> Tuple[torch.Tensor, torch.Tensor]:
def unique_inverse(
x: torch.Tensor,
/,
*,
axis: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Results = namedtuple("Results", ["values", "inverse_indices"])
values, inverse_indices = torch.unique(x, return_inverse=True)

if axis is None:
x = torch.flatten(x)
axis = 0

values, inverse_indices = torch.unique(x, return_inverse=True, axis=axis)
nan_idx = torch.isnan(x)
if nan_idx.any():
inverse_indices[nan_idx] = torch.where(torch.isnan(values))[0][0]
Expand Down
12 changes: 9 additions & 3 deletions ivy/functional/ivy/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def unique_all(
def unique_inverse(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
axis: Optional[int] = None,
) -> Tuple[Union[ivy.Array, ivy.NativeArray], Union[ivy.Array, ivy.NativeArray]]:
"""
Return the unique elements of an input array ``x``, and the indices from the set of
Expand Down Expand Up @@ -192,8 +194,12 @@ def unique_inverse(
Parameters
----------
x
input array. If ``x`` has more than one dimension, the function must flatten
``x`` and return the unique elements of the flattened array.
the arrray that will be inputted into the "unique_inverse" function
axis
the axis to apply unique on. If None, the unique elements of the flattened ``x``
are returned.
Returns
-------
Expand Down Expand Up @@ -253,7 +259,7 @@ def unique_inverse(
b: ivy.array([1, 0, 3, 1, 4, 2, 5])
}]
"""
return ivy.current_backend(x).unique_inverse(x)
return ivy.current_backend(x).unique_inverse(x, axis=axis)


@handle_exceptions
Expand Down
7 changes: 4 additions & 3 deletions ivy_tests/test_ivy/test_functional/test_core/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,17 @@ def test_unique_counts(*, dtype_and_x, test_flags, backend_fw, fn_name, on_devic
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_unique_inverse(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
dtype, x = dtype_and_x
assume(not np.any(np.isclose(x, 0.0)))
def test_unique_inverse(*, dtype_x_axis, test_flags, backend_fw, fn_name, on_device):
dtype, x, axis = dtype_x_axis
assume(not np.any(np.isclose(x, 0.0), axis=axis))

helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
axis=axis,
x=x[0],
)

Expand Down

0 comments on commit 0f33aeb

Please sign in to comment.