Commit 5d844e2

revert unreviewed changes

Madjid Chergui committed Sep 13, 2023
1 parent f554061 commit 5d844e2
Showing 33 changed files with 968 additions and 735 deletions.
8 changes: 2 additions & 6 deletions ivy/functional/backends/jax/experimental/linear_algebra.py
@@ -2,9 +2,6 @@
from typing import Optional, Tuple, Sequence, Union
import jax.numpy as jnp
import jax.scipy.linalg as jla
from . import backend_version

from ivy import with_supported_dtypes
from ivy.functional.backends.jax import JaxArray

import ivy
@@ -117,10 +114,9 @@ def eig(
return jnp.linalg.eig(x)


@with_supported_dtypes(
{"0.4.14 and below": {"float32", "float64", "complex"}}, backend_version
)
def eigvals(x: JaxArray, /) -> JaxArray:
if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
x = x.astype(jnp.float64)
return jnp.linalg.eigvals(x)


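Note: the revert replaces the version-gated with_supported_dtypes decorator on eigvals with an explicit dtype check inside the function, so unsupported inputs are promoted before the JAX call. A minimal standalone sketch of that pattern, using only the jax.numpy and ivy helpers already shown in the diff (the wrapper name below is illustrative, not part of the module):

import jax.numpy as jnp
import ivy

def eigvals_with_promotion(x):
    # Inputs that are not already float32/float64/complex are widened to float64
    # before the eigenvalue routine is called, mirroring the reverted body above.
    if ivy.dtype(x) not in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
        x = x.astype(jnp.float64)
    return jnp.linalg.eigvals(x)

# For example, an int32 matrix is promoted before the decomposition:
vals = eigvals_with_promotion(jnp.array([[2, 0], [0, 3]], dtype=jnp.int32))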
15 changes: 10 additions & 5 deletions ivy/functional/backends/jax/experimental/statistical.py
@@ -120,7 +120,9 @@ def histogram(
return ret


@with_unsupported_dtypes({"0.4.14 and below": "complex"}, backend_version)
@with_unsupported_dtypes(
{"0.4.14 and below": ("complex64", "complex128")}, backend_version
)
def median(
input: JaxArray,
/,
@@ -289,10 +291,6 @@ def cov(
)


@with_unsupported_dtypes(
{"0.4.14 and below": ("bool", "float16", "int8", "int16", "complex", "uint8")},
backend_version,
)
def cummax(
x: JaxArray,
/,
@@ -303,6 +301,13 @@
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> Tuple[JaxArray, JaxArray]:
if x.dtype in (jnp.bool_, jnp.float16):
x = x.astype(jnp.float64)
elif x.dtype in (jnp.int16, jnp.int8, jnp.uint8):
x = x.astype(jnp.int64)
elif x.dtype in (jnp.complex128, jnp.complex64):
x = jnp.real(x).astype(jnp.float64)

if exclusive or (reverse and exclusive):
if exclusive and reverse:
indices = __find_cummax_indices(jnp.flip(x, axis=axis), axis=axis)
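Note: with the with_unsupported_dtypes decorator removed from cummax, the revert handles narrow and complex dtypes explicitly before the cumulative pass. A standalone sketch of that promotion step (the helper name is illustrative only):

import jax.numpy as jnp

def _promote_for_cummax(x):
    # bool and float16 inputs are widened to float64
    if x.dtype in (jnp.bool_, jnp.float16):
        return x.astype(jnp.float64)
    # narrow integer types are widened to int64
    if x.dtype in (jnp.int8, jnp.int16, jnp.uint8):
        return x.astype(jnp.int64)
    # complex inputs are reduced to their real part for the comparison
    if x.dtype in (jnp.complex64, jnp.complex128):
        return jnp.real(x).astype(jnp.float64)
    return x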
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/experimental/general.py
@@ -4,7 +4,7 @@

# local
from . import backend_version
from ivy.func_wrapper import with_unsupported_dtypes
from ivy import with_unsupported_dtypes


@with_unsupported_dtypes({"1.25.2 and below": ("complex",)}, backend_version)
10 changes: 7 additions & 3 deletions ivy/functional/backends/numpy/statistical.py
@@ -171,7 +171,7 @@ def var(
# ------#


@with_unsupported_dtypes({"1.25.2 and below": ("bfloat16", "bool")}, backend_version)
@with_unsupported_dtypes({"1.25.2 and below": "bfloat16"}, backend_version)
def cumprod(
x: np.ndarray,
/,
@@ -183,7 +183,10 @@
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype is None:
dtype = _infer_dtype(x.dtype)
if x.dtype == "bool":
dtype = ivy.default_int_dtype(as_native=True)
else:
dtype = _infer_dtype(x.dtype)
if not (exclusive or reverse):
return np.cumprod(x, axis, dtype=dtype, out=out)
elif exclusive and reverse:
@@ -205,7 +208,6 @@
cumprod.support_native_out = True


@with_unsupported_dtypes({"1.25.2 and below": ("bfloat16", "bool")}, backend_version)
def cumsum(
x: np.ndarray,
axis: int = 0,
@@ -216,6 +218,8 @@
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype is None:
if x.dtype == "bool":
dtype = ivy.default_int_dtype(as_native=True)
if ivy.is_int_dtype(x.dtype):
dtype = ivy.promote_types(x.dtype, ivy.default_int_dtype(as_native=True))
dtype = _infer_dtype(x.dtype)
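Note: the decorators that blanket-rejected bool are removed above; instead a boolean input now receives a default integer accumulation dtype. A quick plain-NumPy illustration of the intended behaviour (values are only examples):

import numpy as np

x = np.array([True, True, False, True])
print(np.cumprod(x, dtype=np.int64))  # [1 1 0 0]
print(np.cumsum(x, dtype=np.int64))   # [1 2 2 3]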
84 changes: 47 additions & 37 deletions ivy/functional/backends/paddle/experimental/activations.py
@@ -5,13 +5,12 @@

# local
import ivy.functional.backends.paddle as paddle_backend
from ivy.func_wrapper import with_supported_dtypes
from ivy.func_wrapper import with_unsupported_device_and_dtypes
from . import backend_version


@with_supported_dtypes(
{"2.5.1 and below": ("float16", "float32", "float64")},
backend_version,
@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version
)
def logit(
x: paddle.Tensor,
@@ -21,54 +20,68 @@ def logit(
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out=None,
):
return paddle.logit(x, eps)
if x.dtype in [paddle.float32, paddle.float64]:
return paddle.logit(x, eps)
if eps is None:
nan = paddle_backend.squeeze(
paddle.to_tensor(float("nan"), dtype=x.dtype), axis=-1
)
x = paddle_backend.where(
paddle_backend.logical_or(
paddle_backend.greater(x, 1), paddle_backend.less(x, 0)
),
nan,
x,
)
else:
x = paddle_backend.minimum(paddle_backend.maximum(x, eps), 1 - eps)
return paddle_backend.log(
paddle_backend.divide(x, paddle_backend.subtract(1, x))
).cast(x.dtype)


@with_supported_dtypes(
{"2.5.1 and below": ("float16", "float32", "float64")},
backend_version,
)
def thresholded_relu(
x: paddle.Tensor,
/,
*,
threshold: Optional[Union[int, float]] = 0,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return F.thresholded_relu(x, threshold=threshold)
if x.dtype in [paddle.float32, paddle.float64]:
return F.thresholded_relu(x, threshold=threshold)
return paddle_backend.where(paddle_backend.greater(x, threshold), x, 0).cast(
x.dtype
)


@with_supported_dtypes(
{"2.5.1 and below": ("float16", "float32", "float64", "complex")},
backend_version,
)
def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.relu6(x)
if paddle.is_complex(x):
return paddle.complex(F.relu6(x.real()), F.relu6(x.imag()))
return F.relu6(x)
return F.relu6(x.cast("float32")).cast(x.dtype)


@with_supported_dtypes(
{"2.5.1 and below": ("float", "complex")},
backend_version,
@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def logsigmoid(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
input: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if paddle.is_complex(x):
if input.dtype in [paddle.float32, paddle.float64]:
return F.log_sigmoid(input)
if paddle.is_complex(input):
return paddle_backend.log(
paddle_backend.divide(
1.0, (paddle_backend.add(1.0, paddle_backend.exp(-x)))
1.0, (paddle_backend.add(1.0, paddle_backend.exp(-input)))
)
)
return F.log_sigmoid(x)
return F.log_sigmoid(input.cast("float32")).cast(input.dtype)


@with_supported_dtypes(
{"2.5.1 and below": ("float", "complex")},
backend_version,
)
def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.selu(x)
if paddle.is_complex(x):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
@@ -81,26 +94,23 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
),
)
return ret
return F.selu(x)
return F.selu(x.cast("float32")).cast(x.dtype)


@with_supported_dtypes(
{"2.5.1 and below": ("float", "complex")},
backend_version,
)
def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.silu(x)
if paddle.is_complex(x):
return x * (1.0 / (1.0 + paddle_backend.exp(-x)))
return F.silu(x)
return F.silu(x.cast("float32")).cast(x.dtype)


@with_supported_dtypes(
{"2.5.1 and below": ("float", "complex")},
backend_version,
)
def elu(
x: paddle.Tensor, /, *, alpha: float = 1.0, out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.elu(x, alpha=alpha)

if paddle.is_complex(x):
ret = (
paddle_backend.where(
@@ -110,4 +120,4 @@ def elu(
),
)
return ret
return F.elu(x, alpha=alpha)
return F.elu(x.cast("float32"), alpha).cast(x.dtype)
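Note: the reverted activations above share one fallback pattern: call the native paddle.nn.functional kernel for float32/float64, handle complex inputs through elementwise identities, and route every other dtype through a float32 round trip. A minimal sketch of that pattern for silu (complex handling omitted; this is an illustration, not the module's exact code):

import paddle
import paddle.nn.functional as F

def silu_with_fallback(x: paddle.Tensor) -> paddle.Tensor:
    # Native kernel path for the dtypes Paddle supports directly.
    if x.dtype in (paddle.float32, paddle.float64):
        return F.silu(x)
    # Other dtypes (e.g. float16 on CPU) go through a float32 round trip
    # and are cast back to the original dtype afterwards.
    return F.silu(x.cast("float32")).cast(x.dtype)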
41 changes: 27 additions & 14 deletions ivy/functional/backends/paddle/experimental/creation.py
@@ -7,6 +7,7 @@
from ivy.functional.backends.paddle.device import to_device
from ivy.func_wrapper import (
with_supported_dtypes,
with_unsupported_device_and_dtypes,
)


@@ -101,7 +102,7 @@ def tril_indices(


@with_supported_dtypes(
{"2.4.2 and below": ("float16", "float32", "float64", "int32", "int64")},
{"2.4.2 and below": ("float64", "float32", "int32", "int64")},
backend_version,
)
def unsorted_segment_min(
@@ -153,17 +154,6 @@ def blackman_window(
).cast(dtype)


@with_supported_dtypes(
{
"2.5.1 and below": (
"bool",
"float",
"int",
"complex",
)
},
backend_version,
)
def unsorted_segment_sum(
data: paddle.Tensor,
segment_ids: paddle.Tensor,
@@ -176,16 +166,39 @@
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)

# Sum computation in paddle does not support int32, so needs to
# be converted to float32
needs_conv = False
if data.dtype == paddle.int32:
data = paddle.cast(data, "float32")
needs_conv = True

res = paddle.zeros((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype)

for i in range(num_segments):
mask_index = segment_ids == i
if paddle.any(mask_index):
res[i] = paddle.sum(data[mask_index], axis=0)

# condition for converting float32 back to int32
if needs_conv is True:
res = paddle.cast(res, "int32")

return res


@with_supported_dtypes(
{"2.5.1 and below": ("int32", "int64", "float32", "float64")},
@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
"cpu": (
"int8",
"int16",
"uint8",
"complex",
)
}
},
backend_version,
)
def trilu(
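Note: unsorted_segment_sum above sums every row of data whose segment_ids entry equals a given segment index, temporarily casting int32 data to float32 because Paddle's sum kernel does not accept int32. A hedged usage sketch (it assumes the op is exposed as ivy.unsorted_segment_sum; the values are illustrative):

import ivy
ivy.set_backend("paddle")

data = ivy.array([[1, 2], [3, 4], [5, 6]], dtype="int32")
segment_ids = ivy.array([0, 1, 0])
# Rows 0 and 2 fall into segment 0, row 1 into segment 1:
#   segment 0 -> [1 + 5, 2 + 6] = [6, 8]
#   segment 1 -> [3, 4]
res = ivy.unsorted_segment_sum(data, segment_ids, 2)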
(Diff truncated; the remaining changed files are not shown here.)
