Skip to content

Commit

Permalink
feat: made logit function support complex dtypes (ivy-llc#23213)
Browse files Browse the repository at this point in the history
  • Loading branch information
mohame54 committed Sep 8, 2023
1 parent 87f703b commit d54c273
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 20 deletions.
14 changes: 11 additions & 3 deletions ivy/data_classes/array/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# global
import abc
from typing import Optional, Union
from typing import Optional, Union, Literal

# local
import ivy


class _ArrayWithActivationsExperimental(abc.ABC):
def logit(
self, /, *, eps: Optional[float] = None, out: Optional[ivy.Array] = None
self,
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.logit. This method simply wraps the
Expand All @@ -23,6 +28,9 @@ def logit(
When eps is None the function outpus NaN where x < 0 or x > 1.
and inf or -inf where x = 1 or x = 0, respectively.
Otherwise if eps is defined, x is clamped to [eps, 1 - eps]
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
Optional output array.
Expand All @@ -43,7 +51,7 @@ def logit(
>>> print(z)
ivy.array([ 1.38629448, 1.38629448, -1.38629436])
"""
return ivy.logit(self, eps=eps, out=out)
return ivy.logit(self, eps=eps, complex_mode=complex_mode, out=out)

def thresholded_relu(
self: ivy.Array,
Expand Down
13 changes: 11 additions & 2 deletions ivy/data_classes/container/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Union, Optional, List, Dict
from typing import Union, Optional, List, Dict, Literal

# local
import ivy
Expand All @@ -13,6 +13,7 @@ def static_logit(
/,
*,
eps: Optional[Union[float, ivy.Container]] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -28,6 +29,9 @@ def static_logit(
When eps is None the function outpus NaN where x < 0 or x > 1.
and inf or -inf where x = 1 or x = 0, respectively.
Otherwise if eps is defined, x is clamped to [eps, 1 - eps]
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
Optional output Contaner.
Expand Down Expand Up @@ -62,6 +66,7 @@ def static_logit(
"logit",
x,
eps=eps,
complex_mode=complex_mode,
out=out,
)

Expand All @@ -70,6 +75,7 @@ def logit(
/,
*,
eps: Optional[Union[float, ivy.Container]] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -85,6 +91,9 @@ def logit(
When eps is None the function outpus NaN where x < 0 or x > 1.
and inf or -inf where x = 1 or x = 0, respectively.
Otherwise if eps is defined, x is clamped to [eps, 1 - eps]
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
Optional output Contaner.
Expand Down Expand Up @@ -115,7 +124,7 @@ def logit(
b: ivy.array([-1.38629436, 1.38629448, -1.38629436])
}
"""
return self.static_logit(self, eps=eps, out=out)
return self.static_logit(self, eps=eps, complex_mode=complex_mode, out=out)

@staticmethod
def static_thresholded_relu(
Expand Down
3 changes: 2 additions & 1 deletion ivy/functional/backends/jax/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Literal

# global
import jax
Expand All @@ -13,6 +13,7 @@ def logit(
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[JaxArray] = None,
):
if eps is None:
Expand Down
3 changes: 2 additions & 1 deletion ivy/functional/backends/numpy/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Literal

# global
import numpy as np
Expand All @@ -15,6 +15,7 @@ def logit(
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[np.ndarray] = None,
):
x_dtype = x.dtype
Expand Down
13 changes: 10 additions & 3 deletions ivy/functional/backends/paddle/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union
from typing import Optional, Union, Literal
import paddle
import paddle.nn.functional as F

Expand All @@ -10,9 +10,16 @@


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("float16",)}}, backend_version
{"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version
)
def logit(x: paddle.Tensor, /, *, eps: Optional[float] = None, out=None):
def logit(
x: paddle.Tensor,
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out=None,
):
if x.dtype in [paddle.float32, paddle.float64]:
return paddle.logit(x, eps)
if eps is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Literal

# global
import tensorflow as tf
Expand All @@ -15,6 +15,7 @@ def logit(
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[Tensor] = None,
) -> Tensor:
x_dtype = x.dtype
Expand Down
3 changes: 2 additions & 1 deletion ivy/functional/backends/torch/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Literal

# global
import torch
Expand All @@ -16,6 +16,7 @@ def logit(
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.logit(x, eps=eps, out=out)
Expand Down
30 changes: 29 additions & 1 deletion ivy/functional/ivy/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Union, Optional
from typing import Union, Optional, Callable, Literal

# local
import ivy
Expand All @@ -14,21 +14,43 @@
inputs_to_ivy_arrays,
handle_device_shifting,
handle_backend_invalid,
handle_complex_input,
)


def _logit_jax_like(
x: Union[float, int, ivy.Array],
/,
*,
fn_original: Optional[Callable] = None,
eps: Optional[float] = None,
out: Optional[ivy.Array] = None,
):
real = ivy.real(x)
imag = ivy.imag(x)
if eps is None:
real = ivy.where(ivy.logical_or(real > 1, real < 0), ivy.nan, real)
else:
real = ivy.clip(real, eps, 1 - eps)
z = ivy.add(real, ivy.multiply(ivy.array(1j, dtype=x.dtype), imag))
z = ivy.log(z / (1 - z))
return z


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_device_shifting
@handle_complex_input
def logit(
x: Union[float, int, ivy.Array],
/,
*,
eps: Optional[float] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -44,6 +66,9 @@ def logit(
When eps is None the function outpus NaN where x < 0 or x > 1.
and inf or -inf where x = 1 or x = 0, respectively.
Otherwise if eps is defined, x is clamped to [eps, 1 - eps]
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
Optional output array.
Expand All @@ -67,6 +92,9 @@ def logit(
return current_backend(x).logit(x, eps=eps, out=out)


logit.jax_like = _logit_jax_like


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
Expand Down
27 changes: 22 additions & 5 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,25 @@ def _forward(self, x):


class Logit(Module):
def __init__(self, eps=None):
"""Apply the LOGIT activation function."""
def __init__(
self,
eps=None,
complex_mode="jax",
):
"""
Apply the LOGIT activation function.
Parameters
----------
eps
The epsilon value for the logit formation. Default: ``None``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""
Module.__init__(self)
self._eps = eps
self._complex_mode = complex_mode

def _forward(self, x):
"""
Expand All @@ -386,15 +401,17 @@ def _forward(self, x):
----------
x
Inputs to process *[batch_shape, d]*.
eps
The epsilon value for the logit formation. Default: ``None``.
Returns
-------
ret
The outputs following the LOGIT activation *[batch_shape, d]*
"""
return ivy.logit(x, eps=self._eps)
return ivy.logit(
x,
eps=self._eps,
complex_mode=self._complex_mode,
)


class PReLU(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_elu(
@handle_test(
fn_tree="functional.ivy.experimental.logit",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("float_and_complex"),
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
Expand Down
2 changes: 1 addition & 1 deletion ivy_tests/test_ivy/test_stateful/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def test_log_softmax(
@handle_method(
method_tree="stateful.activations.Logit.__call__",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("float_and_complex"),
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
Expand Down

0 comments on commit d54c273

Please sign in to comment.