Skip to content

Commit

Permalink
chore: Added backends for ivy.complex and modified torch.complex fron…
Browse files Browse the repository at this point in the history
…tend to use ivy.complex instead of manually handling complex generation
  • Loading branch information
hmahmood24 committed Sep 11, 2024
1 parent 159afa0 commit 9c4c5b1
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ def _assert_array_significant_figures_formatting(sig_figs):
def vec_sig_fig(x, sig_fig=3):
if isinstance(x, np.bool_):
return x
if isinstance(x, complex):
return complex(x)
if isinstance(x, builtins.complex):
return builtins.complex(x)
if np.issubdtype(x.dtype, np.floating):
x_positive = np.where(np.isfinite(x) & (x != 0), np.abs(x), 10 ** (sig_fig - 1))
mags = 10 ** (sig_fig - 1 - np.floor(np.log10(x_positive)))
Expand Down
5 changes: 5 additions & 0 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from typing import Union, Optional, List, Sequence, Tuple

import jax
import jax.dlpack
import jax.numpy as jnp
import jax._src as _src
Expand Down Expand Up @@ -83,6 +84,10 @@ def asarray(
return jnp.copy(ret) if (dev(ret, as_native=True) != device or copy) else ret


def complex(real: JaxArray, imag: JaxArray, out: Optional[JaxArray] = None) -> JaxArray:
return jax.lax.complex(real, imag)


def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/backends/numpy/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def arange(
return res


def complex(
real: np.ndarray, imag: np.ndarray, out: Optional[np.ndarray] = None
) -> np.ndarray:
return real + imag * 1j


@_asarray_to_native_arrays_and_back
@_asarray_infer_device
@_asarray_handle_nestable
Expand Down
8 changes: 8 additions & 0 deletions ivy/functional/backends/paddle/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def arange(
return paddle.arange(start, stop, step).cast(dtype)


def complex(
real: paddle.Tensor,
imag: paddle.Tensor,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return paddle.complex(real, imag)


@_asarray_to_native_arrays_and_back
@_asarray_infer_device
@_asarray_handle_nestable
Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/backends/tensorflow/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def arange(
return tf.range(start, stop, delta=step, dtype=dtype)


def complex(
real: tf.Tensor, imag: tf.Tensor, out: Optional[tf.Tensor] = None
) -> tf.Tensor:
return tf.complex(real, imag)


@_asarray_to_native_arrays_and_back
@_asarray_infer_device
@_asarray_handle_nestable
Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/backends/torch/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def arange(
arange.support_native_out = True


def complex(
real: torch.Tensor, imag: torch.Tensor, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.complex(real, imag, out=out)


def _stack_tensors(x, dtype):
if isinstance(x, (list, tuple)) and len(x) != 0 and isinstance(x[0], (list, tuple)):
for i, item in enumerate(x):
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/torch/creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def complex(
out=None,
):
complex_dtype = ivy.complex64 if real.dtype != ivy.float64 else ivy.complex128
complex_array = real + imag * 1j
complex_array = ivy.complex(real, imag, out=out)
return complex_array.astype(complex_dtype, out=out)


Expand Down
68 changes: 68 additions & 0 deletions ivy/functional/ivy/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,74 @@ def arange(
)


@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@outputs_to_ivy_arrays
@handle_array_function
@handle_device
def complex(
real: Union[ivy.Array, ivy.NativeArray],
imag: Union[ivy.Array, ivy.NativeArray],
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Returns a complex array formed by combining a real and an imaginary
component element-wise. The real and imaginary components must have the
same shape.
Parameters
----------
real
An array representing the real part of the complex numbers.
imag
An array representing the imaginary part of the complex numbers.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Returns
-------
ret
A complex array where each element is formed by combining the corresponding
elements of `real` and `imag`.
This function conforms to the `Array API Standard
<https://data-apis.org/array-api/latest/>`_. This docstring is an extension of the
`docstring <https://data-apis.org/array-api/latest/
API_specification/generated/array_api.arange.html>`_
in the standard.
Both the description and the type hints above assumes an array input for simplicity,
but this function is *nestable*, and therefore also accepts :class:`ivy.Container`
instances in place of any of the arguments.
Examples
--------
>>> real = ivy.array([2.25, 3.25])
>>> imag = ivy.array([4.75, 5.75])
>>> x = ivy.complex(real, imag)
>>> print(x)
ivy.array([2.25+4.75j, 3.25+5.75j])
>>> real = ivy.array(1)
>>> imag = ivy.array(2)
>>> x = ivy.complex(real, imag)
>>> print(x)
ivy.array(1.+2.j)
>>> real = ivy.array([1, 2])
>>> imag = ivy.array([3, 4])
>>> x = ivy.complex(real, imag, step)
>>> print(x)
ivy.array([1.+3.j, 2.+4.j])
"""
return current_backend().complex(real, imag, out=out)


@temp_asarray_wrapper
@handle_backend_invalid
@handle_array_like_without_promotion
Expand Down

0 comments on commit 9c4c5b1

Please sign in to comment.