Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: reformat sorting of sort function #23593

Closed
wants to merge 9 commits into from
82 changes: 82 additions & 0 deletions ivy/functional/backends/jax/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,88 @@ def sort(
stable: bool = True,
out: Optional[JaxArray] = None,
) -> JaxArray:
"""
Return a sorted copy of an array.

Parameters
----------
x : JaxArray
Input array to be sorted.
axis : int, optional
Axis along which to sort. If set to -1, the function must sort along
the last axis. Default: -1.
descending : bool, optional
If True, sort the elements in descending order; if False (default),
sort in ascending order.
stable : bool, optional
If True (default), use a stable sorting algorithm to maintain the
relative order of x values which compare as equal. If False, the
returned indices may or may not maintain the relative order of x
values which compare as equal (i.e., the relative order of x values
which compare as equal is implementation-dependent). Default: True.
out : Optional[JaxArray], optional
An optional output array, for writing the result to. It must have the
same shape as x. Defaults to None.

Returns
-------
JaxArray
An array with the same dtype and shape as x, with the elements sorted
along the given axis.

Examples
--------
With ivy.Array input:

>>> x = ivy.array([7, 8, 6])
>>> y = ivy.sort(x)
>>> print(y)
ivy.array([6, 7, 8])

Sorting along a specific axis:

>>> x = ivy.array([[[8.9,0], [19,5]],[[6,0.3], [19,0.5]]])
>>> y = ivy.sort(x, axis=1, descending=True, stable=False)
>>> print(y)
ivy.array([[[19. , 5. ],[ 8.9, 0. ]],[[19. , 0.5],[ 6. , 0.3]]])

Sorting in descending order:

>>> x = ivy.array([1.5, 3.2, 0.7, 2.5])
>>> y = ivy.zeros(5)
>>> ivy.sort(x, descending=True, stable=False, out=y)
>>> print(y)
ivy.array([3.2, 2.5, 1.5, 0.7])

Using an output array:

>>> x = ivy.array([[1.1, 2.2, 3.3],[-4.4, -5.5, -6.6]])
>>> ivy.sort(x, out=x)
>>> print(x)
ivy.array([[ 1.1, 2.2, 3.3],
[-6.6, -5.5, -4.4]
])

With ivy.Container input:

>>> x = ivy.Container(a=ivy.array([8, 6, 6]),b=ivy.array([[9, 0.7], [0.4, 0]]))
>>> y = ivy.sort(x, descending=True)
>>> print(y)
{
a: ivy.array([8, 6, 6]),
b: ivy.array([[9., 0.7], [0.4, 0.]])
}


>>> x = ivy.Container(a=ivy.array([3, 0.7, 1]),b=ivy.array([[4, 0.9], [0.6, 0.2]]))
>>> y = ivy.sort(x, descending=False, stable=False)
>>> print(y)
{
a: ivy.array([0.7, 1., 3.]),
b: ivy.array([[0.9, 4.], [0.2, 0.6]])
}

"""
kind = "stable" if stable else "quicksort"
ret = jnp.asarray(jnp.sort(x, axis=axis, kind=kind))
if descending:
Expand Down
Loading