diff --git a/ivy/functional/backends/jax/experimental/linear_algebra.py b/ivy/functional/backends/jax/experimental/linear_algebra.py
index c1758b3a9bb21..897254757e7ba 100644
--- a/ivy/functional/backends/jax/experimental/linear_algebra.py
+++ b/ivy/functional/backends/jax/experimental/linear_algebra.py
@@ -2,9 +2,6 @@
 from typing import Optional, Tuple, Sequence, Union
 import jax.numpy as jnp
 import jax.scipy.linalg as jla
-from . import backend_version
-
-from ivy import with_supported_dtypes
 from ivy.functional.backends.jax import JaxArray
 
 import ivy
@@ -117,10 +114,9 @@ def eig(
     return jnp.linalg.eig(x)
 
 
-@with_supported_dtypes(
-    {"0.4.14 and below": {"float32", "float64", "complex"}}, backend_version
-)
 def eigvals(x: JaxArray, /) -> JaxArray:
+    if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
+        x = x.astype(jnp.float64)
     return jnp.linalg.eigvals(x)
 
 
diff --git a/ivy/functional/backends/jax/experimental/statistical.py b/ivy/functional/backends/jax/experimental/statistical.py
index 43d85be81ca3a..9d35e51ae6017 100644
--- a/ivy/functional/backends/jax/experimental/statistical.py
+++ b/ivy/functional/backends/jax/experimental/statistical.py
@@ -120,7 +120,9 @@ def histogram(
     return ret
 
 
-@with_unsupported_dtypes({"0.4.14 and below": "complex"}, backend_version)
+@with_unsupported_dtypes(
+    {"0.4.14 and below": ("complex64", "complex128")}, backend_version
+)
 def median(
     input: JaxArray,
     /,
@@ -289,10 +291,6 @@ def cov(
     )
 
 
-@with_unsupported_dtypes(
-    {"0.4.14 and below": ("bool", "float16", "int8", "int16", "complex", "uint8")},
-    backend_version,
-)
 def cummax(
     x: JaxArray,
     /,
@@ -303,6 +301,13 @@ def cummax(
     dtype: Optional[jnp.dtype] = None,
     out: Optional[JaxArray] = None,
 ) -> Tuple[JaxArray, JaxArray]:
+    if x.dtype in (jnp.bool_, jnp.float16):
+        x = x.astype(jnp.float64)
+    elif x.dtype in (jnp.int16, jnp.int8, jnp.uint8):
+        x = x.astype(jnp.int64)
+    elif x.dtype in (jnp.complex128, jnp.complex64):
+        x = jnp.real(x).astype(jnp.float64)
+
     if exclusive or (reverse and exclusive):
         if exclusive and reverse:
             indices = __find_cummax_indices(jnp.flip(x, axis=axis), axis=axis)
diff --git a/ivy/functional/backends/numpy/experimental/general.py b/ivy/functional/backends/numpy/experimental/general.py
index 37bd5d3b8a4e1..50bd626ccb591 100644
--- a/ivy/functional/backends/numpy/experimental/general.py
+++ b/ivy/functional/backends/numpy/experimental/general.py
@@ -4,7 +4,7 @@
 
 # local
 from . import backend_version
-from ivy.func_wrapper import with_unsupported_dtypes
+from ivy import with_unsupported_dtypes
 
 
 @with_unsupported_dtypes({"1.25.2 and below": ("complex",)}, backend_version)
diff --git a/ivy/functional/backends/numpy/statistical.py b/ivy/functional/backends/numpy/statistical.py
index b07f35c4224a0..419ef9355d336 100644
--- a/ivy/functional/backends/numpy/statistical.py
+++ b/ivy/functional/backends/numpy/statistical.py
@@ -171,7 +171,7 @@ def var(
 # ------#
 
 
-@with_unsupported_dtypes({"1.25.2 and below": ("bfloat16", "bool")}, backend_version)
+@with_unsupported_dtypes({"1.25.2 and below": "bfloat16"}, backend_version)
 def cumprod(
     x: np.ndarray,
     /,
@@ -183,7 +183,10 @@ def cumprod(
     out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
     if dtype is None:
-        dtype = _infer_dtype(x.dtype)
+        if x.dtype == "bool":
+            dtype = ivy.default_int_dtype(as_native=True)
+        else:
+            dtype = _infer_dtype(x.dtype)
     if not (exclusive or reverse):
         return np.cumprod(x, axis, dtype=dtype, out=out)
     elif exclusive and reverse:
@@ -205,7 +208,6 @@ def cumprod(
 cumprod.support_native_out = True
 
 
-@with_unsupported_dtypes({"1.25.2 and below": ("bfloat16", "bool")}, backend_version)
 def cumsum(
     x: np.ndarray,
     axis: int = 0,
@@ -216,6 +218,8 @@ def cumsum(
     out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
     if dtype is None:
+        if x.dtype == "bool":
+            dtype = ivy.default_int_dtype(as_native=True)
         if ivy.is_int_dtype(x.dtype):
             dtype = ivy.promote_types(x.dtype, ivy.default_int_dtype(as_native=True))
         dtype = _infer_dtype(x.dtype)
diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py
index 5fa5162fb795c..33f1d80bbf8ef 100644
--- a/ivy/functional/backends/paddle/experimental/activations.py
+++ b/ivy/functional/backends/paddle/experimental/activations.py
@@ -5,13 +5,12 @@
 
 # local
 import ivy.functional.backends.paddle as paddle_backend
-from ivy.func_wrapper import with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_device_and_dtypes
 from . import backend_version
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64")},
-    backend_version,
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version
 )
 def logit(
     x: paddle.Tensor,
@@ -21,13 +20,26 @@ def logit(
     complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     out=None,
 ):
-    return paddle.logit(x, eps)
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return paddle.logit(x, eps)
+    if eps is None:
+        nan = paddle_backend.squeeze(
+            paddle.to_tensor(float("nan"), dtype=x.dtype), axis=-1
+        )
+        x = paddle_backend.where(
+            paddle_backend.logical_or(
+                paddle_backend.greater(x, 1), paddle_backend.less(x, 0)
+            ),
+            nan,
+            x,
+        )
+    else:
+        x = paddle_backend.minimum(paddle_backend.maximum(x, eps), 1 - eps)
+    return paddle_backend.log(
+        paddle_backend.divide(x, paddle_backend.subtract(1, x))
+    ).cast(x.dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64")},
-    backend_version,
-)
 def thresholded_relu(
     x: paddle.Tensor,
     /,
@@ -35,40 +47,41 @@ def thresholded_relu(
     threshold: Optional[Union[int, float]] = 0,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    return F.thresholded_relu(x, threshold=threshold)
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.thresholded_relu(x, threshold=threshold)
+    return paddle_backend.where(paddle_backend.greater(x, threshold), x, 0).cast(
+        x.dtype
+    )
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64", "complex")},
-    backend_version,
-)
 def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.relu6(x)
     if paddle.is_complex(x):
         return paddle.complex(F.relu6(x.real()), F.relu6(x.imag()))
-    return F.relu6(x)
+    return F.relu6(x.cast("float32")).cast(x.dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float", "complex")},
-    backend_version,
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
 )
 def logsigmoid(
-    x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
+    input: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
-    if paddle.is_complex(x):
+    if input.dtype in [paddle.float32, paddle.float64]:
+        return F.log_sigmoid(input)
+    if paddle.is_complex(input):
         return paddle_backend.log(
             paddle_backend.divide(
-                1.0, (paddle_backend.add(1.0, paddle_backend.exp(-x)))
+                1.0, (paddle_backend.add(1.0, paddle_backend.exp(-input)))
             )
         )
-    return F.log_sigmoid(x)
+    return F.log_sigmoid(input.cast("float32")).cast(input.dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float", "complex")},
-    backend_version,
-)
 def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.selu(x)
     if paddle.is_complex(x):
         alpha = 1.6732632423543772848170429916717
         scale = 1.0507009873554804934193349852946
@@ -81,26 +94,23 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
             ),
         )
         return ret
-    return F.selu(x)
+    return F.selu(x.cast("float32")).cast(x.dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float", "complex")},
-    backend_version,
-)
 def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.silu(x)
     if paddle.is_complex(x):
         return x * (1.0 / (1.0 + paddle_backend.exp(-x)))
-    return F.silu(x)
+    return F.silu(x.cast("float32")).cast(x.dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float", "complex")},
-    backend_version,
-)
 def elu(
     x: paddle.Tensor, /, *, alpha: float = 1.0, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
+    if x.dtype in [paddle.float32, paddle.float64]:
+        return F.elu(x, alpha=alpha)
+
     if paddle.is_complex(x):
         ret = (
             paddle_backend.where(
@@ -110,4 +120,4 @@ def elu(
             ),
         )
         return ret
-    return F.elu(x, alpha=alpha)
+    return F.elu(x.cast("float32"), alpha).cast(x.dtype)
diff --git a/ivy/functional/backends/paddle/experimental/creation.py b/ivy/functional/backends/paddle/experimental/creation.py
index 409f167686c63..1de567c08e0bc 100644
--- a/ivy/functional/backends/paddle/experimental/creation.py
+++ b/ivy/functional/backends/paddle/experimental/creation.py
@@ -7,6 +7,7 @@
 from ivy.functional.backends.paddle.device import to_device
 from ivy.func_wrapper import (
     with_supported_dtypes,
+    with_unsupported_device_and_dtypes,
 )
 
 
@@ -101,7 +102,7 @@ def tril_indices(
 
 
 @with_supported_dtypes(
-    {"2.4.2 and below": ("float16", "float32", "float64", "int32", "int64")},
+    {"2.4.2 and below": ("float64", "float32", "int32", "int64")},
     backend_version,
 )
 def unsorted_segment_min(
@@ -153,17 +154,6 @@ def blackman_window(
     ).cast(dtype)
 
 
-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "bool",
-            "float",
-            "int",
-            "complex",
-        )
-    },
-    backend_version,
-)
 def unsorted_segment_sum(
     data: paddle.Tensor,
     segment_ids: paddle.Tensor,
@@ -176,16 +166,39 @@ def unsorted_segment_sum(
     ivy.utils.assertions.check_unsorted_segment_min_valid_params(
         data, segment_ids, num_segments
     )
+
+    # Sum computation in paddle does not support int32, so needs to
+    # be converted to float32
+    needs_conv = False
+    if data.dtype == paddle.int32:
+        data = paddle.cast(data, "float32")
+        needs_conv = True
+
     res = paddle.zeros((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype)
+
     for i in range(num_segments):
         mask_index = segment_ids == i
         if paddle.any(mask_index):
             res[i] = paddle.sum(data[mask_index], axis=0)
+
+    # condition for converting float32 back to int32
+    if needs_conv is True:
+        res = paddle.cast(res, "int32")
+
     return res
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("int32", "int64", "float32", "float64")},
+@with_unsupported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": (
+                "int8",
+                "int16",
+                "uint8",
+                "complex",
+            )
+        }
+    },
     backend_version,
 )
 def trilu(
diff --git a/ivy/functional/backends/paddle/experimental/elementwise.py b/ivy/functional/backends/paddle/experimental/elementwise.py
index 1b006730b8c97..8fe58b9e1519a 100644
--- a/ivy/functional/backends/paddle/experimental/elementwise.py
+++ b/ivy/functional/backends/paddle/experimental/elementwise.py
@@ -17,7 +17,10 @@
 from ..
import backend_version -@with_supported_dtypes({"2.5.1 and below": ("float",)}, backend_version) +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64")}, + backend_version, +) def lgamma( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: @@ -25,7 +28,7 @@ def lgamma( @with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, + {"2.5.1 and below": ("float64", "float32", "int32", "int64")}, backend_version, ) def fmax( @@ -40,8 +43,8 @@ def fmax( return paddle.fmax(x1, x2) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version ) def sinc(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: y = ivy.pi * paddle.where(x == 0, paddle.to_tensor(1.0e-20, dtype=x.dtype), x) @@ -104,8 +107,8 @@ def copysign( return result -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("uint8", "int8", "int16", "float16")}}, backend_version ) def nansum( x: paddle.Tensor, @@ -122,7 +125,9 @@ def nansum( return result -@with_supported_dtypes({"2.5.1 and below": "float"}, backend_version) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version +) def isclose( a: paddle.Tensor, b: paddle.Tensor, @@ -136,9 +141,6 @@ def isclose( return paddle.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "bool", "int32", "int64")}, backend_version -) def diff( x: Union[paddle.Tensor, list, tuple], /, @@ -150,6 +152,8 @@ def diff( out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: ret_dtype = x.dtype + if x.dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16]: + x = x.cast("float32") def _tensor(val): if val is not None and not isinstance(val, paddle.Tensor): @@ -184,7 +188,24 @@ def hypot( raise IvyNotImplementedException() -@with_supported_dtypes({"2.5.1 and below": "float"}, backend_version) +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "float16", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, +) def allclose( x1: paddle.Tensor, x2: paddle.Tensor, @@ -248,8 +269,22 @@ def nextafter( ] -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64", "bool")}, backend_version +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "float16", + "bool", + ) + } + }, + backend_version, ) def zeta( x: paddle.Tensor, @@ -572,8 +607,10 @@ def count_nonzero( @with_supported_dtypes( { "2.5.1 and below": ( - "complex", - "float", + "complex64", + "complex128", + "float32", + "float64", "int32", "int64", ) @@ -592,7 +629,12 @@ def modf( @with_supported_dtypes( - {"2.5.0 and below": ("float",)}, + { + "2.5.0 and below": ( + "float32", + "float64", + ) + }, backend_version, ) def digamma( diff --git a/ivy/functional/backends/paddle/experimental/layers.py b/ivy/functional/backends/paddle/experimental/layers.py index 23a774d4274ea..1369a98b8d6b3 100644 --- a/ivy/functional/backends/paddle/experimental/layers.py +++ b/ivy/functional/backends/paddle/experimental/layers.py @@ -297,7 +297,6 @@ def dct( raise IvyNotImplementedException() -@with_supported_dtypes({"2.5.1 and 
below": "complex"}, backend_version) def fft( x: paddle.Tensor, dim: int, @@ -331,6 +330,11 @@ def fft( f" {valid_norm_modes}" ) + if x.dtype in [paddle.int64, paddle.float64, paddle.complex128]: + x = x.cast(paddle.complex128) + else: + x = x.cast(paddle.complex64) + return paddle.fft.fft(x, n, dim, norm=norm) diff --git a/ivy/functional/backends/paddle/experimental/linear_algebra.py b/ivy/functional/backends/paddle/experimental/linear_algebra.py index 530354955e006..abc29c7dc5b81 100644 --- a/ivy/functional/backends/paddle/experimental/linear_algebra.py +++ b/ivy/functional/backends/paddle/experimental/linear_algebra.py @@ -5,16 +5,15 @@ # local from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size from ivy.func_wrapper import ( + with_unsupported_device_and_dtypes, with_supported_device_and_dtypes, - with_supported_dtypes, ) from ivy.utils.exceptions import IvyNotImplementedException from .. import backend_version -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, - backend_version, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int8", "int16", "uint8", "float16")}}, backend_version ) def diagflat( x: paddle.Tensor, @@ -46,8 +45,8 @@ def diagflat( )(diag) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int8", "uint8", "int16")}}, backend_version ) def kron( a: paddle.Tensor, @@ -70,34 +69,16 @@ def matrix_exp( raise IvyNotImplementedException() -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, backend_version -) def eig( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> Tuple[paddle.Tensor]: return paddle.linalg.eig(x) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, backend_version -) def eigvals(x: paddle.Tensor, /) -> paddle.Tensor: return paddle.linalg.eig(x)[0] -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def adjoint( x: paddle.Tensor, /, diff --git a/ivy/functional/backends/paddle/experimental/losses.py b/ivy/functional/backends/paddle/experimental/losses.py index 732c25e0f03fe..a582b43e15fb2 100644 --- a/ivy/functional/backends/paddle/experimental/losses.py +++ b/ivy/functional/backends/paddle/experimental/losses.py @@ -4,12 +4,27 @@ import paddle.nn.functional as F # local -from ivy.func_wrapper import with_supported_dtypes, with_unsupported_device_and_dtypes +from ivy.func_wrapper import with_unsupported_device_and_dtypes from . 
import backend_version -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "float16", + "int8", + "int16", + "int32", + "int64", + "uint8", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, ) def l1_loss( input: paddle.Tensor, @@ -21,7 +36,23 @@ def l1_loss( return F.l1_loss(input, target, reduction=reduction) -@with_supported_dtypes({"2.5.1 and below": ("float",)}, backend_version) +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, +) def smooth_l1_loss( input: paddle.Tensor, target: paddle.Tensor, diff --git a/ivy/functional/backends/paddle/experimental/manipulation.py b/ivy/functional/backends/paddle/experimental/manipulation.py index abe17df2ddc58..88099753ec83b 100644 --- a/ivy/functional/backends/paddle/experimental/manipulation.py +++ b/ivy/functional/backends/paddle/experimental/manipulation.py @@ -12,6 +12,7 @@ ) from numbers import Number + from .. import backend_version from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_supported_dtypes import paddle @@ -87,21 +88,6 @@ ] -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float16", - "float32", - "float64", - "int32", - "int64", - "complex64", - "complex128", - ) - }, - backend_version, -) def moveaxis( a: paddle.Tensor, source: Union[int, Sequence[int]], @@ -115,6 +101,8 @@ def moveaxis( source = list(source) if isinstance(destination, tuple): source = list(destination) + if a.dtype in [paddle.int8, paddle.int16, paddle.uint8]: + return paddle.moveaxis(a.cast("float32"), source, destination).cast(a.dtype) return paddle.moveaxis(a, source, destination) @@ -172,8 +160,21 @@ def pad( ) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "uint8", + "float16", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, ) def heaviside( x1: paddle.Tensor, @@ -185,10 +186,6 @@ def heaviside( return paddle.heaviside(x1, x2) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64", "bool")}, - backend_version, -) def flipud( m: paddle.Tensor, /, @@ -196,6 +193,8 @@ def flipud( copy: Optional[bool] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if m.dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16]: + return paddle.flip(m.cast("float32"), axis=0).cast(m.dtype) return paddle.flip(m, axis=0) @@ -289,10 +288,6 @@ def top_k( return topk_res(val, indices) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64", "bool")}, - backend_version, -) def fliplr( m: paddle.Tensor, /, @@ -300,6 +295,8 @@ def fliplr( copy: Optional[bool] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if m.dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16]: + return paddle.flip(m.cast("float32"), axis=1).cast(m.dtype) return paddle.flip(m, axis=1) @@ -342,10 +339,6 @@ def _chbevl(x, vals): return paddle_backend.where(paddle_backend.less_equal(x, 8.0), _i0_1(x), _i0_2(x)) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int", "uint8", "complex")}, - backend_version, -) def flatten( x: paddle.Tensor, /, @@ -361,11 +354,20 @@ def flatten( return x def _flatten(x, start_dim, 
end_dim): - if paddle.is_complex(x): - return paddle.complex( - paddle.flatten(x.real(), start_axis=start_dim, stop_axis=end_dim), - paddle.flatten(x.imag(), start_axis=start_dim, stop_axis=end_dim), - ) + if x.dtype in [ + paddle.float16, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(x): + return paddle.complex( + paddle.flatten(x.real(), start_axis=start_dim, stop_axis=end_dim), + paddle.flatten(x.imag(), start_axis=start_dim, stop_axis=end_dim), + ) + return paddle.flatten( + x.cast("float32"), start_axis=start_dim, stop_axis=end_dim + ).cast(x.dtype) return paddle.flatten(x, start_axis=start_dim, stop_axis=end_dim) if order == "F": @@ -477,15 +479,12 @@ def atleast_3d( return res -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "int32", - "int64", - "uint8", - ) - }, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int8",)}}, + backend_version, +) +@with_supported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int32", "int64", "float32", "float64")}}, backend_version, ) def take_along_axis( @@ -537,10 +536,22 @@ def take_along_axis( arr = ivy.concat([arr, fill_arr], axis=axis) indices = ivy.where(indices < 0, arr.shape[axis] + indices, indices) - if paddle.is_complex(arr): - return paddle.complex( - paddle.take_along_axis(arr.real(), indices, axis), - paddle.take_along_axis(arr.imag(), indices, axis), + if arr.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(arr): + return paddle.complex( + paddle.take_along_axis(arr.real(), indices, axis), + paddle.take_along_axis(arr.imag(), indices, axis), + ) + return paddle.take_along_axis(arr.cast("float32"), indices, axis).cast( + arr.dtype ) return paddle.take_along_axis(arr, indices, axis) @@ -669,17 +680,8 @@ def unique_consecutive( ) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - "int32", - "int64", - "float16", - ) - }, - backend_version, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int8", "int16", "uint8", "float16")}}, backend_version ) def fill_diagonal( a: paddle.Tensor, diff --git a/ivy/functional/backends/paddle/experimental/statistical.py b/ivy/functional/backends/paddle/experimental/statistical.py index 45254bf397893..336fe507f94bc 100644 --- a/ivy/functional/backends/paddle/experimental/statistical.py +++ b/ivy/functional/backends/paddle/experimental/statistical.py @@ -11,8 +11,20 @@ from . 
import backend_version -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64", "bool", "complex")}, +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "uint8", + "float16", + "complex64", + "complex128", + "bool", + ) + } + }, backend_version, ) def median( @@ -28,11 +40,14 @@ def median( # only axis in the tensor so it needs to be handled manually ret_dtype = input.dtype - if paddle.is_complex(input): - ret = paddle.complex( - paddle.median(input.real(), axis=axis, keepdim=True), - paddle.median(input.imag(), axis=axis, keepdim=True), - ) + if input.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: + if paddle.is_complex(input): + ret = paddle.complex( + paddle.median(input.real(), axis=axis, keepdim=True), + paddle.median(input.imag(), axis=axis, keepdim=True), + ) + else: + ret = paddle.median(input.cast("float32"), axis=axis, keepdim=True) else: ret = paddle.median(input, axis=axis, keepdim=True) if not keepdims: @@ -47,10 +62,6 @@ def median( return ret.astype(ret_dtype) -@with_supported_dtypes( - {"2.5.1 and below": ("float",)}, - backend_version, -) def nanmean( a: paddle.Tensor, /, @@ -64,11 +75,14 @@ def nanmean( a = a.cast( ret_dtype ) # this is necessary to match other FWs behaviour which cast before calculation - if paddle.is_complex(a): - ret = paddle.complex( - paddle.nanmean(a.real(), axis=axis, keepdim=keepdims), - paddle.nanmean(a.imag(), axis=axis, keepdim=keepdims), - ) + if a.dtype not in [paddle.int64, paddle.float32, paddle.float64]: + if paddle.is_complex(a): + ret = paddle.complex( + paddle.nanmean(a.real(), axis=axis, keepdim=keepdims), + paddle.nanmean(a.imag(), axis=axis, keepdim=keepdims), + ) + else: + ret = paddle.nanmean(a.cast("float32"), axis=axis, keepdim=keepdims) else: ret = paddle.nanmean(a, axis=axis, keepdim=keepdims) @@ -345,10 +359,6 @@ def histogram( return paddle.histogram(a, bins=bins, min=min_range, max=max_range) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64")}, - backend_version, -) def nanmedian( input: paddle.Tensor, /, @@ -359,13 +369,26 @@ def nanmedian( overwrite_input: Optional[bool] = False, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - if dtype is None: - dtype = input.dtype + if input.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: + if dtype is None: + dtype = input.dtype + input = input.cast("float32") + paddle.nanmedian(x=input, axis=axis, keepdim=keepdims).cast(dtype) return paddle.nanmedian(x=input, axis=axis, keepdim=keepdims).cast(dtype) -@with_supported_dtypes( - {"2.5.1 and below": ("float",)}, +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "uint8", + "float16", + "bool", + ) + } + }, backend_version, ) def unravel_index( @@ -386,8 +409,22 @@ def unravel_index( return tuple(reversed(coord)) -@with_supported_dtypes( - {"2.5.1 and below": ("int32", "int64")}, +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "uint8", + "float16", + "float32", + "float64", + "complex64", + "complex128", + "bool", + ) + } + }, backend_version, ) def bincount( @@ -403,7 +440,6 @@ def bincount( ) -@with_supported_dtypes({"2.5.1 and below": ("float",)}, backend_version) def igamma( a: paddle.Tensor, /, @@ -413,15 +449,20 @@ def igamma( ) -> paddle.Tensor: results = [] ret_dtype = a.dtype if out is None else out.dtype + if paddle.float16 in [a.dtype, x.dtype]: + a = a.astype("float32") + x = 
x.astype("float32") for ai, xi in zip(a.flatten(), x.flatten()): + ai = ai.astype("float64") + xi = xi.astype("float64") - def _integrand(t): + def integrand(t): return paddle.exp(-t) * paddle.pow(t, ai - 1) intervals = paddle.linspace(0, xi, 10001).astype("float64") interval_width = xi / 10000 - values = _integrand(intervals) + values = integrand(intervals) integral = paddle.multiply((values[:-1] + values[1:]) / 2, interval_width) result = paddle.divide(paddle.sum(integral), paddle.exp(paddle.lgamma(ai))) results.append(result) @@ -429,7 +470,6 @@ def _integrand(t): return paddle.to_tensor(results, dtype=ret_dtype).reshape(a.shape) -@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, backend_version) def cov( x1: paddle.Tensor, x2: paddle.Tensor = None, @@ -464,6 +504,15 @@ def cov( else: ddof = 0 + if dtype is None: + x1 = x1.astype("float64") + if x2 is not None: + x2 = x2.astype("float64") + else: + x1 = x1.astype(dtype) + if x2 is not None: + x2 = x2.astype(dtype) + X = x1 if not rowVar and X.shape[0] != 1: X = paddle.transpose(X, perm=tuple(range(len(X.shape) - 1, -1, -1))) @@ -484,8 +533,8 @@ def cov( ) -@with_supported_dtypes( - {"2.5.1 and below": ("int64", "float64", "complex")}, backend_version +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("uint16", "bfloat16")}}, backend_version ) def cummax( x: paddle.Tensor, @@ -497,7 +546,11 @@ def cummax( dtype: Optional[paddle.dtype] = None, out: Optional[paddle.Tensor] = None, ) -> Tuple[paddle.Tensor, paddle.Tensor]: - if x.dtype in (paddle.complex128, paddle.complex64): + if x.dtype in (paddle.bool, paddle.float16): + x = paddle.cast(x, "float64") + elif x.dtype in (paddle.int16, paddle.int8, paddle.uint8): + x = paddle.cast(x, "int64") + elif x.dtype in (paddle.complex128, paddle.complex64): x = paddle.cast(paddle.real(x), "float64") if not (exclusive or reverse): @@ -608,8 +661,9 @@ def __get_index(lst, indices=None, prefix=None): return indices -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("uint8", "int8", "int16")}}, + backend_version, ) def cummin( x: paddle.Tensor, diff --git a/ivy/functional/backends/paddle/general.py b/ivy/functional/backends/paddle/general.py index f2663cfcdadea..f4258c4f452b9 100644 --- a/ivy/functional/backends/paddle/general.py +++ b/ivy/functional/backends/paddle/general.py @@ -9,10 +9,8 @@ # local import ivy import ivy.functional.backends.paddle as paddle_backend -from ivy import with_supported_dtypes from ivy.functional.ivy.general import _broadcast_to from ivy.utils.exceptions import _check_inplace_update_support -from . 
import backend_version def is_native_array(x, /, *, exclusive=False): @@ -102,20 +100,6 @@ def to_list(x: paddle.Tensor, /) -> list: return x.tolist() -@with_supported_dtypes( - { - "2.5.1 and below": [ - "float16", - "float32", - "float64", - "int16", - "int32", - "int64", - "uint8", - ] - }, - backend_version, -) def gather( params: paddle.Tensor, indices: paddle.Tensor, @@ -166,23 +150,20 @@ def _gather(params1): if batch_dims is not None: batch_dims = batch_dims % params.ndim ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) + if params.dtype in [ + paddle.int8, + paddle.int16, + paddle.float16, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(params): + return paddle.complex(_gather(params.real()), _gather(params.imag())) + return _gather(params.cast("float32")).cast(params.dtype) return _gather(params) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float16", - "float32", - "float64", - "int16", - "int32", - "int64", - ) - }, - backend_version, -) def gather_nd( params: paddle.Tensor, indices: paddle.Tensor, @@ -282,7 +263,21 @@ def gather_nd( # flat_indices now has shape [(B1.B2), i1, ..., iK, C] indices = paddle_backend.concat((index_grid, flat_indices), axis=-1) # indices has shape [(B1.B2), i1, ..., iK, 2+C] - out = paddle.gather_nd(params, indices) + if params.dtype in [ + paddle.int8, + paddle.float16, + paddle.complex64, + paddle.complex128, + ]: + if paddle.is_complex(params): + out = paddle.complex( + paddle.gather_nd(params.real(), indices), + paddle.gather_nd(params.imag(), indices), + ) + else: + out = paddle.gather_nd(params.cast("float32"), indices).cast(params.dtype) + else: + out = paddle.gather_nd(params, indices) # out has shape [(B1.B2), i1, ..., iK, N-C]. Now we reshape batch to # its original form. out_shape = out.shape @@ -386,10 +381,6 @@ def scatter_flat( ) -@with_supported_dtypes( - {"2.5.1 and below": ("int32", "int64", "float32", "float64", "complex")}, - backend_version, -) def scatter_nd( indices: paddle.Tensor, updates: paddle.Tensor, @@ -484,6 +475,23 @@ def scatter_nd( updates.imag(), ) ret = paddle.complex(result_real, result_imag) + elif target_dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + target, updates, updates_ = ( + target.cast("float32"), + updates.cast("float32"), + updates_.cast("float32"), + ) + ret = paddle.scatter_nd_add( + paddle.scatter_nd_add(target, indices, -updates_), + indices, + updates, + ).cast(target_dtype) else: ret = paddle.scatter_nd_add( paddle.scatter_nd_add(target, indices, -updates_), diff --git a/ivy/functional/backends/paddle/linear_algebra.py b/ivy/functional/backends/paddle/linear_algebra.py index 1222bbdf881f8..8a99ab77110fc 100644 --- a/ivy/functional/backends/paddle/linear_algebra.py +++ b/ivy/functional/backends/paddle/linear_algebra.py @@ -11,19 +11,28 @@ from ivy.utils.exceptions import IvyNotImplementedException import ivy.functional.backends.paddle as paddle_backend from . 
import backend_version -from ivy.func_wrapper import ( - with_unsupported_device_and_dtypes, - with_unsupported_dtypes, - with_supported_dtypes, -) +from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_unsupported_dtypes from .elementwise import _elementwise_helper # Array API Standard # # -------------------# -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "float16", + "complex", + "bool", + ) + } + }, backend_version, ) def cholesky( @@ -32,10 +41,6 @@ def cholesky( return paddle.linalg.cholesky(x, upper=upper) -@with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - backend_version, -) def cross( x1: paddle.Tensor, x2: paddle.Tensor, @@ -56,30 +61,53 @@ def _cross(x1, x2, axisa, axisb, axisc, axis): return paddle.moveaxis(ret, 1, axisc) x1, x2, ret_dtype = _elementwise_helper(x1, x2) + if x1.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(x1): + return paddle.complex( + _cross(x1.real(), x2.real(), axisa, axisb, axisc, axis), + _cross(x1.real(), x2.real(), axisa, axisb, axisc, axis), + ) + return _cross( + x1.cast("float32"), + x2.cast("float32"), + axisa, + axisb, + axisc, + axis, + ).cast(ret_dtype) return _cross(x1, x2, axisa, axisb, axisc, axis) -@with_supported_dtypes({"2.5.1 and below": ("float",)}, backend_version) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, + backend_version, +) def det(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: - ret = paddle.linalg.det(x) + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + ret = paddle.linalg.det(x.cast("float32")).cast(x.dtype) + else: + ret = paddle.linalg.det(x) if x.ndim == 2: ret = paddle_backend.squeeze(ret, axis=0) return ret -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int32", - "int64", - "float16", - "float32", - "float64", - ) - }, - backend_version, -) def diagonal( x: paddle.Tensor, /, @@ -89,13 +117,25 @@ def diagonal( axis2: int = -1, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.complex64, + paddle.complex128, + ]: + if paddle.is_complex(x): + return paddle.complex( + paddle.diagonal(x.real(), offset=offset, axis1=axis1, axis2=axis2), + paddle.diagonal(x.imag(), offset=offset, axis1=axis1, axis2=axis2), + ) + return paddle.diagonal( + x.cast("float32"), offset=offset, axis1=axis1, axis2=axis2 + ).cast(x.dtype) return paddle.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, - backend_version, -) def eigh( x: paddle.Tensor, /, @@ -110,10 +150,6 @@ def eigh( return result_tuple(eigenvalues, eigenvectors) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, - backend_version, -) def eigvalsh( x: paddle.Tensor, /, @@ -124,20 +160,26 @@ def eigvalsh( return paddle.linalg.eigvalsh(x, UPLO=UPLO) -@with_supported_dtypes( - {"2.5.1 and below": ("float",)}, - backend_version, -) def inner( x1: paddle.Tensor, x2: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: x1, x2 = 
ivy.promote_types_of_inputs(x1, x2) ret_dtype = x1.dtype + if x1.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + x1, x2 = x1.cast("float32"), x2.cast("float32") return paddle.inner(x1, x2).squeeze().cast(ret_dtype) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def inv( @@ -148,20 +190,21 @@ def inv( out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: ret_dtype = x.dtype + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + x = x.cast("float32") if adjoint: x = paddle.moveaxis(x, -2, -1).conj() return paddle.inverse(x).cast(ret_dtype) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "complex", - ) - }, - backend_version, -) def matmul( x1: paddle.Tensor, x2: paddle.Tensor, @@ -174,21 +217,37 @@ def matmul( out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: x1, x2 = ivy.promote_types_of_inputs(x1, x2) + ret_dtype = x1.dtype + if x1.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + x1, x2 = x1.cast("float32"), x2.cast("float32") if adjoint_a: x1 = paddle.moveaxis(x1, -2, -1).conj() if adjoint_b: x2 = paddle.moveaxis(x2, -2, -1).conj() - ret = paddle.matmul(x1, x2, transpose_x=transpose_a, transpose_y=transpose_b) + ret = paddle.matmul(x1, x2, transpose_x=transpose_a, transpose_y=transpose_b).cast( + ret_dtype + ) # handle case where ret should be 0d. if x1.ndim == 1 and x2.ndim == 1: - return ret.squeeze() + ret_dtype = ret.dtype + if ret_dtype in [paddle.int16]: + ret = ret.cast(paddle.int32) + return ret.squeeze().astype(ret_dtype) return ret -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def matrix_norm( @@ -259,10 +318,6 @@ def matrix_norm( return ret -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, - backend_version, -) def eig( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> Tuple[paddle.Tensor]: @@ -273,8 +328,8 @@ def eig( return result_tuple(eigenvalues, eigenvectors) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def matrix_power( @@ -283,8 +338,8 @@ def matrix_power( return paddle.linalg.matrix_power(x, n) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def matrix_rank( @@ -321,18 +376,6 @@ def matrix_rank( return ret -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def matrix_transpose( x: paddle.Tensor, /, @@ -345,35 +388,39 @@ def matrix_transpose( x = paddle.conj(x) perm = list(range(x.ndim)) perm[-1], perm[-2] = perm[-2], perm[-1] + if x.dtype in [paddle.int8, paddle.int16, paddle.uint8]: + return paddle.transpose(x.cast("float32"), perm=perm).cast(x.dtype) return paddle.transpose(x, perm=perm) -@with_supported_dtypes( - {"2.5.1 and below": ("float",)}, - backend_version, -) def outer( x1: paddle.Tensor, x2: 
paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: x1, x2 = ivy.promote_types_of_inputs(x1, x2) - return paddle.outer(x1, x2) + ret_dtype = x1.dtype + if x1.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + x1, x2 = x1.cast("float32"), x2.cast("float32") + return paddle.outer(x1, x2).cast(ret_dtype) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "complex")}, - backend_version, -) def pinv( x: paddle.Tensor, /, *, rtol: Optional[Union[float, Tuple[float]]] = None, - hermitian: Optional[bool] = False, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: if rtol is None: - return paddle.linalg.pinv(x, hermitian=hermitian) - return paddle.linalg.pinv(x, rcond=rtol, hermitian=hermitian) + return paddle.linalg.pinv(x) + return paddle.linalg.pinv(x, rcond=rtol) def tensorsolve( @@ -388,8 +435,8 @@ def tensorsolve( raise IvyNotImplementedException() -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def qr( @@ -404,8 +451,8 @@ def qr( return res(q, r) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def slogdet( @@ -423,8 +470,8 @@ def slogdet( return results(sign, logabsdet) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def solve( @@ -451,14 +498,26 @@ def solve( return ret -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def svd( x: paddle.Tensor, /, *, full_matrices: bool = True, compute_uv: bool = True ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - ret = paddle.linalg.svd(x, full_matrices=full_matrices) + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + ret = paddle.linalg.svd(x.cast("float32"), full_matrices=full_matrices) + ret = tuple(r.cast(x.dtype) for r in ret) + else: + ret = paddle.linalg.svd(x, full_matrices=full_matrices) if compute_uv: results = namedtuple("svd", "U S Vh") return results(*ret) @@ -467,8 +526,8 @@ def svd( return results(ret[1]) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64")}, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("complex64", "complex128")}}, backend_version, ) def svdvals( @@ -477,10 +536,6 @@ def svdvals( return paddle_backend.svd(x)[1] -@with_supported_dtypes( - {"2.5.1 and below": ("float",)}, - backend_version, -) def tensordot( x1: paddle.Tensor, x2: paddle.Tensor, @@ -490,17 +545,33 @@ def tensordot( out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: x1, x2 = ivy.promote_types_of_inputs(x1, x2) + ret_dtype = x1.dtype + if x1.dtype in [ + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + x1, x2 = x1.cast("float32"), x2.cast("float32") ret = paddle.tensordot(x1, x2, axes=axes) - return ret.squeeze() if x1.ndim == axes else ret + return ret.squeeze().cast(ret_dtype) if x1.ndim == axes else ret.cast(ret_dtype) -@with_supported_dtypes( 
+@with_unsupported_device_and_dtypes( { - "2.5.1 and below": ( - "int32", - "int64", - "float", - ) + "2.5.1 and below": { + "cpu": ( + "int8", + "int16", + "unsigned", + "float16", + "complex", + "bool", + ) + } }, backend_version, ) @@ -566,17 +637,6 @@ def vector_norm( # ----- # -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def diag( x: paddle.Tensor, /, @@ -584,15 +644,24 @@ def diag( k: int = 0, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - if paddle.is_complex(x): - return paddle.complex( - paddle.diag(x.real(), offset=k), paddle.diag(x.imag(), offset=k) - ) + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(x): + return paddle.complex( + paddle.diag(x.real(), offset=k), paddle.diag(x.imag(), offset=k) + ) + return paddle.diag(x.cast("float32"), offset=k).cast(x.dtype) return paddle.diag(x, offset=k) @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("uint8", "int8", "int16", "complex")}}, + {"2.5.1 and below": {"cpu": ("uint8", "int8", "int16", "complex64", "complex128")}}, backend_version, ) def vander( diff --git a/ivy/functional/backends/paddle/manipulation.py b/ivy/functional/backends/paddle/manipulation.py index 80c685ff15cee..bf654b3f253a4 100644 --- a/ivy/functional/backends/paddle/manipulation.py +++ b/ivy/functional/backends/paddle/manipulation.py @@ -7,7 +7,7 @@ # local import ivy import ivy.functional.backends.paddle as paddle_backend -from ivy.func_wrapper import with_supported_dtypes +from ivy.func_wrapper import with_unsupported_device_and_dtypes # noinspection PyProtectedMember from . import backend_version @@ -69,10 +69,6 @@ def expand_dims( return x.reshape(out_shape) -@with_supported_dtypes( - {"2.5.1 and below": ("float", "int32", "int64", "bool")}, - backend_version, -) def flip( x: paddle.Tensor, /, @@ -83,21 +79,11 @@ def flip( ) -> paddle.Tensor: if axis is None: axis = list(range(x.ndim)) + if x.dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16]: + return paddle.flip(x.cast("float32"), axis).cast(x.dtype) return paddle.flip(x, axis) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def permute_dims( x: paddle.Tensor, /, @@ -106,6 +92,8 @@ def permute_dims( copy: Optional[bool] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if x.dtype in [paddle.int8, paddle.int16, paddle.uint8]: + return paddle.transpose(x.cast("float32"), axes).cast(x.dtype) return paddle.transpose(x, axes) @@ -163,17 +151,6 @@ def reshape( return ret -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def roll( x: paddle.Tensor, /, @@ -182,22 +159,17 @@ def roll( axis: Optional[Union[int, Sequence[int]]] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: + return paddle.roll(x.cast("float32"), shift, axis).cast(x.dtype) return paddle.roll(x, shift, axis) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "bool", - "int8", - "int32", - "int64", - "complex", - ) - }, - backend_version, -) def squeeze( x: paddle.Tensor, /, @@ -219,18 +191,13 @@ def squeeze( x_shape = x.shape x_shape.pop(axis) return paddle_backend.reshape(x, x_shape) + if x.dtype in 
[paddle.int16, paddle.float16]: + return paddle.squeeze(x.cast("float32"), axis=axis).cast(x.dtype) return paddle.squeeze(x, axis=axis) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "int32", - "int64", - "complex", - ) - }, +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("int16", "uint8", "int8", "float16")}}, backend_version, ) def stack( @@ -256,7 +223,11 @@ def stack( first_shape[:axis] + [len(arrays)] + first_shape[axis:], dtype=dtype ) - if dtype in [ + if dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16, paddle.bool]: + arrays = list(map(lambda x: x.cast("float32"), arrays)) + return paddle.stack(arrays, axis=axis).cast(dtype) + + elif dtype in [ paddle.complex64, paddle.complex128, ]: @@ -274,21 +245,6 @@ def stack( # ------# -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "bfloat16", - "float", - "int32", - "int64", - "uint8", - "int8", - "complex", - ) - }, - backend_version, -) def split( x: paddle.Tensor, /, @@ -331,17 +287,16 @@ def split( f" got {sum(num_or_size_splits)} which is more than x.shape[axis]" ) - if paddle.is_complex(x): - imag_list = paddle.split(x.imag(), num_or_size_splits, axis) - real_list = paddle.split(x.real(), num_or_size_splits, axis) - return [paddle.complex(a, b) for a, b in zip(real_list, imag_list)] + if x.dtype in [paddle.int16, paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): + imag_list = paddle.split(x.imag(), num_or_size_splits, axis) + real_list = paddle.split(x.real(), num_or_size_splits, axis) + return [paddle.complex(a, b) for a, b in zip(real_list, imag_list)] + ret = paddle.split(x.cast("int32"), num_or_size_splits, axis) + return [tensor.cast(x.dtype) for tensor in ret] return paddle.split(x, num_or_size_splits, axis) -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64", "complex")}, - backend_version, -) def repeat( x: paddle.Tensor, /, @@ -366,26 +321,28 @@ def repeat( if axis is not None: axis = axis % x.ndim - if paddle.is_complex(x): - return paddle.complex( - paddle.repeat_interleave(x.real(), repeats=repeats, axis=axis), - paddle.repeat_interleave(x.imag(), repeats=repeats, axis=axis), - ) + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.complex64, + paddle.complex128, + paddle.bool, + ]: + if paddle.is_complex(x): + return paddle.complex( + paddle.repeat_interleave(x.real(), repeats=repeats, axis=axis), + paddle.repeat_interleave(x.imag(), repeats=repeats, axis=axis), + ) + + return paddle.repeat_interleave( + x.cast("float32"), repeats=repeats, axis=axis + ).cast(x.dtype) return paddle.repeat_interleave(x, repeats=repeats, axis=axis) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "float", - "int32", - "int64", - ) - }, - backend_version, -) def tile( x: paddle.Tensor, /, repeats: Sequence[int], *, out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: @@ -421,20 +378,11 @@ def tile( shape = paddle_backend.multiply(x.shape, repeats).tolist() return paddle.zeros(shape).cast(x.dtype) + if x.dtype in [paddle.int8, paddle.int16, paddle.uint8, paddle.float16]: + return paddle.tile(x.cast("float32"), repeats).cast(x.dtype) return paddle.tile(x, repeats) -@with_supported_dtypes( - { - "2.5.1 and below": ( - "float", - "int32", - "int64", - "complex", - ), - }, - backend_version, -) def constant_pad( x: paddle.Tensor, /, @@ -451,6 +399,16 @@ def constant_pad( else: paddings.append(item[0]) paddings.append(item[1]) + if x.dtype in [ + paddle.int8, + 
paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        return paddle.nn.functional.pad(
+            x.cast("float32"), pad=paddings, value=value
+        ).cast(x.dtype)
     return paddle.nn.functional.pad(x=x, pad=paddings, value=value)
@@ -508,18 +466,6 @@ def clip(
     return x


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "bool",
-            "float",
-            "int32",
-            "int64",
-            "complex",
-        )
-    },
-    backend_version,
-)
 def unstack(
     x: paddle.Tensor,
     /,
@@ -534,10 +480,22 @@ def unstack(
         axis = axis % x.ndim
     else:
         axis = 0
-    if paddle.is_complex(x):
-        real_list = paddle.unbind(x.real(), axis)
-        imag_list = paddle.unbind(x.imag(), axis)
-        ret = [paddle.complex(a, b) for a, b in zip(real_list, imag_list)]
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            real_list = paddle.unbind(x.real(), axis)
+            imag_list = paddle.unbind(x.imag(), axis)
+            ret = [paddle.complex(a, b) for a, b in zip(real_list, imag_list)]
+        else:
+            ret = paddle.unbind(x.cast("float32"), axis)
+            ret = list(map(lambda a: a.cast(x.dtype), ret))
+    else:
         ret = paddle.unbind(x, axis)
     if keepdims:
diff --git a/ivy/functional/backends/paddle/random.py b/ivy/functional/backends/paddle/random.py
index 9d983265e78ef..3ca640ccf531b 100644
--- a/ivy/functional/backends/paddle/random.py
+++ b/ivy/functional/backends/paddle/random.py
@@ -15,7 +15,7 @@
 )
 from ivy.func_wrapper import (
     with_unsupported_device_and_dtypes,
-    with_supported_dtypes,
+    with_supported_device_and_dtypes,
 )
 from . import backend_version
@@ -23,8 +23,8 @@
 # ------#


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float",)},
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("int8",)}},
     backend_version,
 )
 def random_uniform(
@@ -54,8 +54,8 @@ def random_uniform(
     )


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64")},
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
     backend_version,
 )
 def random_normal(
@@ -80,8 +80,15 @@ def random_normal(
     return paddle.normal(mean, std).cast(dtype)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float",)},
+@with_supported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": (
+                "float32",
+                "float64",
+            )
+        }
+    },
     backend_version,
 )
 def multinomial(
@@ -142,10 +149,6 @@ def seed(*, seed_value: int = 0) -> None:
     return


-@with_supported_dtypes(
-    {"2.5.1 and below": ("int32", "int64", "float32", "float64", "complex")},
-    backend_version,
-)
 def shuffle(
     x: paddle.Tensor,
     axis: Optional[int] = 0,
@@ -158,8 +161,18 @@ def shuffle(
         _ = paddle.seed(seed)
     # Use Paddle's randperm function to generate shuffled indices
     indices = paddle.randperm(x.ndim, dtype="int64")
-    if paddle.is_complex(x):
-        shuffled_real = paddle.index_select(x.real(), indices, axis=axis)
-        shuffled_imag = paddle.index_select(x.imag(), indices, axis=axis)
-        return paddle.complex(shuffled_real, shuffled_imag)
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            shuffled_real = paddle.index_select(x.real(), indices, axis=axis)
+            shuffled_imag = paddle.index_select(x.imag(), indices, axis=axis)
+            return paddle.complex(shuffled_real, shuffled_imag)
+        return paddle.index_select(x.cast("float32"), indices, axis=axis).cast(x.dtype)
     return paddle.index_select(x, indices, axis=axis)
diff --git a/ivy/functional/backends/paddle/searching.py b/ivy/functional/backends/paddle/searching.py
index f058421847941..9519e03c23916 100644
--- a/ivy/functional/backends/paddle/searching.py
+++ b/ivy/functional/backends/paddle/searching.py
@@ -4,7 +4,7 @@
 import paddle
 import ivy.functional.backends.paddle as paddle_backend
 import ivy
-from ivy.func_wrapper import with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_device_and_dtypes
 from . import backend_version
 from .elementwise import _elementwise_helper
@@ -12,15 +12,14 @@
 # ------------------ #


-@with_supported_dtypes(
+@with_unsupported_device_and_dtypes(
     {
-        "2.5.1 and below": (
-            "float",
-            "int16",
-            "int32",
-            "int64",
-            "uint8",
-        )
+        "2.5.1 and below": {
+            "cpu": (
+                "complex64",
+                "complex128",
+            )
+        }
     },
     backend_version,
 )
@@ -35,6 +34,8 @@ def argmax(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else paddle.int64
+    if x.dtype in [paddle.int8, paddle.float16, paddle.bool]:
+        x = x.cast("float32")
     if select_last_index:
         x = paddle_backend.flip(x, axis=axis)
         ret = paddle.argmax(x, axis=axis, keepdim=keepdims)
@@ -52,17 +53,14 @@ def argmax(
     return ret.astype(dtype)


-@with_supported_dtypes(
+@with_unsupported_device_and_dtypes(
     {
-        "2.5.1 and below": (
-            "float16",
-            "float32",
-            "float64",
-            "int16",
-            "int32",
-            "int64",
-            "uint8",
-        )
+        "2.5.1 and below": {
+            "cpu": (
+                "complex64",
+                "complex128",
+            )
+        }
     },
     backend_version,
 )
@@ -77,6 +75,8 @@ def argmin(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else paddle.int64
+    if x.dtype in [paddle.int8, paddle.float16, paddle.bool]:
+        x = x.cast("float32")
     if select_last_index:
         x = paddle_backend.flip(x, axis=axis)
         ret = paddle.argmin(x, axis=axis, keepdim=keepdims)
@@ -94,21 +94,6 @@ def argmin(
     return ret.astype(dtype)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": [
-            "int16",
-            "int32",
-            "int64",
-            "float16",
-            "float32",
-            "float64",
-            "bool",
-            "complex",
-        ]
-    },
-    backend_version,
-)
 def nonzero(
     x: paddle.Tensor,
     /,
@@ -117,11 +102,20 @@ def nonzero(
     size: Optional[int] = None,
     fill_value: Number = 0,
 ) -> Union[paddle.Tensor, Tuple[paddle.Tensor]]:
-    if paddle.is_complex(x):
-        real_idx = paddle.nonzero(x.real())
-        imag_idx = paddle.nonzero(x.imag())
-        idx = paddle.concat([real_idx, imag_idx], axis=0)
-        res = paddle.unique(idx, axis=0)
+    if x.dtype in [
+        paddle.int8,
+        paddle.uint8,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+    ]:
+        if paddle.is_complex(x):
+            real_idx = paddle.nonzero(x.real())
+            imag_idx = paddle.nonzero(x.imag())
+            idx = paddle.concat([real_idx, imag_idx], axis=0)
+            res = paddle.unique(idx, axis=0)
+        else:
+            res = paddle.nonzero(x.cast("float32"))
     else:
         res = paddle.nonzero(x)
@@ -146,19 +140,6 @@ def nonzero(
     return res.T


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float16",
-            "float32",
-            "float64",
-            "int32",
-            "int64",
-            "complex",
-        )
-    },
-    backend_version,
-)
 def where(
     condition: paddle.Tensor,
     x1: Union[float, int, paddle.Tensor],
@@ -176,43 +157,46 @@ def where(
     condition, x1, x2 = arrays
     condition = condition.cast("bool") if condition.dtype != paddle.bool else condition
-    if ret_dtype in [paddle.complex64, paddle.complex128]:
+    if ret_dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x1 = x1.cast("float32")
+        x2 = x2.cast("float32")
+        result = paddle.where(condition, x1, x2)
+    elif ret_dtype in [paddle.complex64, paddle.complex128]:
         result_real = paddle.where(condition, paddle.real(x1), paddle.real(x2))
         result_imag = paddle.where(condition, paddle.imag(x1), paddle.imag(x2))
         result = paddle.complex(result_real, result_imag)
     else:
         result = paddle.where(condition, x1, x2)
-    return result.squeeze() if scalar_out else result
+    return result.squeeze().cast(ret_dtype) if scalar_out else result.cast(ret_dtype)


 # Extra #
 # ----- #


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": [
-            "int16",
-            "int32",
-            "int64",
-            "float16",
-            "float32",
-            "float64",
-            "bool",
-            "complex",
-        ]
-    },
-    backend_version,
-)
 def argwhere(
     x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
     if x.ndim == 0:
         return paddle.zeros(shape=[int(bool(x.item())), 0], dtype="int64")
-    if paddle.is_complex(x):
-        real_idx = paddle.nonzero(x.real())
-        imag_idx = paddle.nonzero(x.imag())
-        idx = paddle.concat([real_idx, imag_idx], axis=0)
-        return paddle.unique(idx, axis=0)
+    if x.dtype in [
+        paddle.int8,
+        paddle.uint8,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+    ]:
+        if paddle.is_complex(x):
+            real_idx = paddle.nonzero(x.real())
+            imag_idx = paddle.nonzero(x.imag())
+            idx = paddle.concat([real_idx, imag_idx], axis=0)
+            return paddle.unique(idx, axis=0)
+        return paddle.nonzero(x.cast("float32"))
     return paddle.nonzero(x)
diff --git a/ivy/functional/backends/paddle/set.py b/ivy/functional/backends/paddle/set.py
index d2d312ddebed4..0864ed4ec3880 100644
--- a/ivy/functional/backends/paddle/set.py
+++ b/ivy/functional/backends/paddle/set.py
@@ -3,16 +3,13 @@
 from typing import Tuple, Optional
 from collections import namedtuple
 import ivy.functional.backends.paddle as paddle_backend
-from ivy.func_wrapper import with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_unsupported_dtypes

 # local
 from . import backend_version


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64")},
-    backend_version,
-)
+@with_unsupported_dtypes({"2.5.1 and below": ("complex",)}, backend_version)
 def unique_all(
     x: paddle.Tensor,
     /,
@@ -25,7 +22,16 @@ def unique_all(
         ["values", "indices", "inverse_indices", "counts"],
     )

-    x_dtype = x.dtype
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x, x_dtype = x.cast("float32"), x.dtype
+    else:
+        x_dtype = x.dtype
     if axis is not None:
         axis = axis % x.ndim
     values, inverse_indices, counts = paddle.unique(
@@ -88,12 +94,14 @@ def unique_all(
     )


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")},
-    backend_version,
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex",)}}, backend_version
 )
 def unique_counts(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
-    x_dtype = x.dtype
+    if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
+        x, x_dtype = x.cast("float32"), x.dtype
+    else:
+        x_dtype = x.dtype
     unique, counts = paddle.unique(x, return_counts=True)
     nan_count = paddle.count_nonzero(paddle.where(paddle.isnan(x) > 0)).numpy()[0]
@@ -114,12 +122,14 @@ def unique_counts(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
     return Results(unique.cast(x_dtype), counts)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64")},
-    backend_version,
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex",)}}, backend_version
 )
 def unique_inverse(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
-    x_dtype = x.dtype
+    if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
+        x, x_dtype = x.cast("float32"), x.dtype
+    else:
+        x_dtype = x.dtype
     unique, inverse_val = paddle.unique(x, return_inverse=True)
     nan_idx = paddle.where(paddle.isnan(x) > 0)
     nan_count = paddle.count_nonzero(nan_idx).numpy()[0]
@@ -138,14 +148,16 @@ def unique_inverse(x: paddle.Tensor, /) -> Tuple[paddle.Tensor, paddle.Tensor]:
     return Results(unique.cast(x_dtype), inverse_val)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64")},
-    backend_version,
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex",)}}, backend_version
 )
 def unique_values(
     x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
-    x_dtype = x.dtype
+    if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
+        x, x_dtype = x.cast("float32"), x.dtype
+    else:
+        x_dtype = x.dtype
     nan_count = paddle.sum(paddle.isnan(x))
     unique = paddle.unique(x)
     if nan_count > 0:
diff --git a/ivy/functional/backends/paddle/sorting.py b/ivy/functional/backends/paddle/sorting.py
index 31465decbc7d3..6a1d8f38f10d8 100644
--- a/ivy/functional/backends/paddle/sorting.py
+++ b/ivy/functional/backends/paddle/sorting.py
@@ -4,20 +4,12 @@
 # local
 import ivy
-from ivy.func_wrapper import with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_device_and_dtypes
 from . import backend_version


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float",
-            "int16",
-            "int32",
-            "int64",
-            "uint8",
-        )
-    },
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
     backend_version,
 )
 def argsort(
@@ -29,20 +21,19 @@ def argsort(
     stable: bool = True,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x = x.cast("float32")
     return paddle.argsort(x, axis=axis, descending=descending)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float32",
-            "float64",
-            "int16",
-            "int32",
-            "int64",
-            "uint8",
-        )
-    },
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
     backend_version,
 )
 def sort(
@@ -54,18 +45,21 @@ def sort(
     stable: bool = True,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        return paddle.sort(x.cast("float32"), axis=axis, descending=descending).cast(
+            x.dtype
+        )
     return paddle.sort(x, axis=axis, descending=descending)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float32",
-            "float64",
-            "int32",
-            "int64",
-        )
-    },
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
     backend_version,
 )
 def searchsorted(
@@ -78,6 +72,24 @@ def searchsorted(
     ret_dtype=paddle.int64,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x = x.cast("float32")
+
+    if v.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        v = v.cast("float32")
+
     right = True if side == "right" else False
     assert ivy.is_int_dtype(ret_dtype), ValueError(
         "only Integer data types are supported for ret_dtype."
@@ -102,17 +114,8 @@ def searchsorted(
     return paddle.searchsorted(x, v, right=right).cast(ret_dtype)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float32",
-            "float64",
-            "int16",
-            "int32",
-            "int64",
-            "uint8",
-        )
-    },
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("int8", "uint8", "int16", "float16", "complex")}},
     backend_version,
 )
 def msort(
diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py
index e1d6a80b4cca0..0f294cde0e285 100644
--- a/ivy/functional/backends/paddle/statistical.py
+++ b/ivy/functional/backends/paddle/statistical.py
@@ -8,7 +8,8 @@
 import ivy
 from ivy.utils.exceptions import IvyNotImplementedException
 from ivy.func_wrapper import (
-    with_supported_dtypes,
+    with_unsupported_device_and_dtypes,
+    with_supported_device_and_dtypes,
 )
 import ivy.functional.backends.paddle as paddle_backend
@@ -19,20 +20,6 @@
 # -------------------#


-def _with_complex_support(x, fn, axis, keepdims):
-    if paddle.is_complex(x):
-        real_part = fn(x.real(), axis=axis, keepdim=keepdims)
-        imag_part = fn(x.imag(), axis=axis, keepdim=keepdims)
-        ret = paddle.complex(real_part, imag_part)
-    else:
-        ret = fn(x, axis=axis, keepdim=keepdims)
-    return ret
-
-
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64", "complex")},
-    backend_version,
-)
 def min(
     x: paddle.Tensor,
     /,
@@ -42,7 +29,24 @@ def min(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     ret_dtype = x.dtype
-    ret = _with_complex_support(x, paddle.amin, axis=axis, keepdims=keepdims)
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bfloat16,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            real = paddle.amin(x.real(), axis=axis, keepdim=keepdims)
+            imag = paddle.amin(x.imag(), axis=axis, keepdim=keepdims)
+            ret = paddle.complex(real, imag)
+        else:
+            ret = paddle.amin(x.cast("float32"), axis=axis, keepdim=keepdims)
+    else:
+        ret = paddle.amin(x, axis=axis, keepdim=keepdims)
     # The following code is to simulate other frameworks
     # output shapes behaviour since min output dim is 1 in paddle
     if isinstance(axis, Sequence):
@@ -53,10 +57,6 @@ def min(
     return ret.astype(ret_dtype)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64", "complex")},
-    backend_version,
-)
 def max(
     x: paddle.Tensor,
     /,
@@ -65,7 +65,33 @@ def max(
     keepdims: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    ret = _with_complex_support(x, paddle.amax, axis, keepdims)
+    ret_dtype = x.dtype
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.bfloat16,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            const = paddle.to_tensor(1j, dtype=x.dtype)
+            real_max = paddle.max(x.real(), axis=axis, keepdim=keepdims)
+            imag = paddle.where(
+                x.real() == real_max, x.imag(), paddle.full_like(x.imag(), -1e10)
+            )
+            # we consider the number with the biggest real and imag part
+            img_max = paddle.max(imag, axis=axis, keepdim=keepdims)
+            img_max = paddle.cast(img_max, x.dtype)
+            return paddle.add(
+                paddle.cast(real_max, x.dtype), paddle.multiply(img_max, const)
+            )
+        else:
+            ret = paddle.amax(x.cast("float32"), axis=axis, keepdim=keepdims)
+    else:
+        ret = paddle.amax(x, axis=axis, keepdim=keepdims)

     # The following code is to simulate other frameworks
     # output shapes behaviour since min output dim is 1 in paddle
@@ -74,13 +100,9 @@ def max(
         axis = None
     if (x.ndim == 1 or axis is None) and not keepdims:
         ret = ret.squeeze()
-    return ret.astype(x.dtype)
+    return ret.astype(ret_dtype)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64", "complex")},
-    backend_version,
-)
 def mean(
     x: paddle.Tensor,
     /,
@@ -89,7 +111,20 @@ def mean(
     keepdims: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    ret = _with_complex_support(x, paddle.mean, axis=axis, keepdims=keepdims)
+    ret_dtype = x.dtype
+    if x.dtype not in [
+        paddle.float32,
+        paddle.float64,
+    ]:
+        if paddle.is_complex(x):
+            ret = paddle.complex(
+                paddle.mean(x.real(), axis=axis, keepdim=keepdims),
+                paddle.mean(x.imag(), axis=axis, keepdim=keepdims),
+            )
+        else:
+            ret = paddle.mean(x.cast("float32"), axis=axis, keepdim=keepdims)
+    else:
+        ret = paddle.mean(x, axis=axis, keepdim=keepdims)

     # The following code is to simulate other frameworks
     # output shapes behaviour since min output dim is 1 in paddle
@@ -98,21 +133,9 @@ def mean(
         axis = None
     if (x.ndim == 1 or axis is None) and not keepdims:
         ret = ret.squeeze()
-    return ret.astype(x.dtype)
+    return ret.astype(ret_dtype)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float32",
-            "float64",
-            "int32",
-            "int64",
-            "float16",
-        )
-    },
-    backend_version,
-)
 def prod(
     x: paddle.Tensor,
     /,
@@ -154,10 +177,6 @@ def _std(x, axis, correction, keepdim):
     return out


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64")},
-    backend_version,
-)
 def std(
     x: paddle.Tensor,
     /,
@@ -170,20 +189,6 @@ def std(
     return _std(x, axis, correction, keepdims).cast(x.dtype)


-@with_supported_dtypes(
-    {
-        "2.5.1 and below": (
-            "float16",
-            "float32",
-            "float64",
-            "bool",
-            "int32",
-            "int64",
-            "complex",
-        )
-    },
-    backend_version,
-)
 def sum(
     x: paddle.Tensor,
     /,
@@ -195,7 +200,10 @@ def sum(
 ) -> paddle.Tensor:
     dtype = x.dtype if dtype is None else dtype
     dtype = ivy.as_ivy_dtype(dtype)
-    ret = paddle.sum(x.cast(dtype), axis=axis, dtype=dtype, keepdim=keepdims)
+    if x.dtype in [paddle.int8, paddle.uint8]:
+        ret = paddle.sum(x.cast("float32"), axis=axis, dtype=dtype, keepdim=keepdims)
+    else:
+        ret = paddle.sum(x.cast(dtype), axis=axis, dtype=dtype, keepdim=keepdims)
     # The following code is to simulate other frameworks
     # output shapes behaviour since min output dim is 1 in paddle
     if isinstance(axis, Sequence):
@@ -221,16 +229,11 @@ def var(
 # Extra #
 # ----- #


-@with_supported_dtypes(
+@with_supported_device_and_dtypes(
     {
-        "2.5.1 and below": (
-            "complex",
-            "float16",
-            "float32",
-            "float64",
-            "int32",
-            "int64",
-        )
+        "2.5.1 and below": {
+            "cpu": ("int32", "int64", "float64", "complex128", "float32", "complex64")
+        }
     },
     backend_version,
 )
@@ -246,6 +249,13 @@ def cumprod(
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else x.dtype
     x = paddle.cast(x, dtype)
+    if ivy.as_native_dtype(dtype) in [
+        paddle.uint8,
+        paddle.int8,
+        paddle.int16,
+        paddle.float16,
+    ]:
+        x = paddle.cast(x, "float32")
     if not (exclusive or reverse):
         return paddle.cumprod(x, dim=axis).cast(dtype)
     elif exclusive and reverse:
@@ -280,8 +290,8 @@ def cumprod(
     return paddle_backend.flip(x, axis=axis).cast(dtype)


-@with_supported_dtypes(
-    {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")},
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
     backend_version,
 )
 def cumsum(
@@ -295,6 +305,13 @@ def cumsum(
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else x.dtype
     x = paddle.cast(x, dtype)
+    if ivy.as_native_dtype(dtype) in [
+        paddle.uint8,
+        paddle.int8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x = paddle.cast(x, "float32")
     if not (exclusive or reverse):
         return paddle.cumsum(x, axis=axis).cast(dtype)
     elif exclusive and reverse:
diff --git a/ivy/functional/backends/paddle/utility.py b/ivy/functional/backends/paddle/utility.py
index 3c97fdca4302e..3bb7e8311ac4b 100644
--- a/ivy/functional/backends/paddle/utility.py
+++ b/ivy/functional/backends/paddle/utility.py
@@ -2,14 +2,8 @@
 import paddle
 from typing import Union, Optional, Sequence
 import ivy.functional.backends.paddle as paddle_backend
-from ivy.func_wrapper import with_supported_dtypes
-from . import backend_version


-@with_supported_dtypes(
-    {"2.5.1 and below": "bool"},
-    backend_version,
-)
 def all(
     x: paddle.Tensor,
     /,
@@ -18,6 +12,7 @@ def all(
     keepdims: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    x = x.cast("bool")
     if axis is None:
         axis = list(range(x.ndim))
     if isinstance(axis, int):
@@ -38,10 +33,6 @@ def all(
     return ret


-@with_supported_dtypes(
-    {"2.5.1 and below": "bool"},
-    backend_version,
-)
 def any(
     x: paddle.Tensor,
     /,
@@ -50,6 +41,7 @@ def any(
     keepdims: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    x = x.cast("bool")
     if axis is None:
         axis = list(range(x.ndim))
     if isinstance(axis, int):
diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py
index f3fdb4dc8909b..a44d13c771f44 100644
--- a/ivy/functional/backends/tensorflow/elementwise.py
+++ b/ivy/functional/backends/tensorflow/elementwise.py
@@ -600,8 +600,8 @@ def positive(
     return tf.experimental.numpy.positive(x)


-@with_supported_dtypes(
-    {"2.13.0 and below": ("float", "int32", "int64", "complex")},
+@with_unsupported_dtypes(
+    {"2.13.0 and below": ("uint8", "uint16", "uint32", "uint64", "float64")},
     backend_version,
 )
 def pow(
@@ -620,6 +620,14 @@ def pow(
         ret = tf.experimental.numpy.power(x1, x2)
         return tf.where(x1 == 0, ivy.nan + ivy.nan * 1j, ret)
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
+    if isinstance(x1, tf.Tensor) and isinstance(x2, tf.Tensor):
+        if x1.dtype.is_unsigned or x2.dtype.is_unsigned:
+            promoted_type = tf.experimental.numpy.promote_types(x1.dtype, x2.dtype)
+            if x1.dtype.is_unsigned:
+                x1 = tf.cast(x1, tf.float64)
+            if x2.dtype.is_unsigned:
+                x2 = tf.cast(x2, tf.float64)
+            return tf.cast(tf.experimental.numpy.power(x1, x2), promoted_type)
     if ivy.is_int_dtype(x1) and ivy.any(x2 < 0):
         return tf.cast(
             tf.experimental.numpy.power(tf.cast(x1, tf.float32), x2),
@@ -668,10 +676,6 @@ def round(
     return tf.cast(tf.round(x * factor) / factor_deno, ret_dtype)


-@with_supported_dtypes(
-    {"2.13.0 and below": ("bfloat16", "float", "int32", "int64", "complex")},
-    backend_version,
-)
 def sign(
     x: Union[tf.Tensor, tf.Variable],
     /,
@@ -679,6 +683,8 @@ def sign(
     np_variant: Optional[bool] = True,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
+    if x.dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
+        return tf.cast(tf.math.sign(tf.cast(x, tf.float32)), x.dtype)
     if x.dtype in [tf.complex64, tf.complex128] and np_variant:
         real = tf.math.real(x)
         imag = tf.math.imag(x)
diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py
index 367687664b3cd..96dba1c88a4dd 100644
--- a/ivy/functional/backends/tensorflow/experimental/elementwise.py
+++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py
@@ -246,7 +246,6 @@ def _normalize_axis_tuple(axis: Union[int, list, tuple], ndim: int) -> Tuple[int
     return axis


-@with_unsupported_dtypes({"2.13.0 and below": ("int", "unsigned")}, backend_version)
 def gradient(
     x: tf.Tensor,
     /,
@@ -325,6 +324,9 @@ def gradient(
     slice3 = [slice(None)] * N
     slice4 = [slice(None)] * N

+    if x.dtype.is_integer:
+        x = x.astype(tf.experimental.numpy.float64)
+
     for axis, ax_dx in zip(axes, dx):
         if x.shape[axis] < edge_order + 1:
             raise ValueError(
diff --git a/ivy/functional/backends/tensorflow/experimental/layers.py b/ivy/functional/backends/tensorflow/experimental/layers.py
index 6f776e26507b2..da6dc3578bf87 100644
--- a/ivy/functional/backends/tensorflow/experimental/layers.py
+++ b/ivy/functional/backends/tensorflow/experimental/layers.py
@@ -942,13 +942,14 @@ def _fft2_norm(
         raise ivy.utils.exceptions.IvyError(f"Unrecognized normalization mode {norm}")


-@with_supported_dtypes({"2.13.0 and below": ("float32", "complex")}, backend_version)
 def trans_x_to_s(
     x: Union[tf.Tensor, tf.Variable],
     s: Sequence[int] = None,
     dim: Sequence[int] = (-2, -1),
 ) -> Union[tf.Tensor, tf.Variable]:
     """Change the shape of the input array x to the desired output shape s."""
+    if x.dtype != tf.complex128 and x.dtype != tf.complex64:
+        x = tf.cast(x, tf.float32)
     x_shape = x.shape
     if dim == (-1, -2) or dim == (1, 0):
         s = (s[1], s[0])
diff --git a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
index bd09aaf43034d..0732c302ca956 100644
--- a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
@@ -94,25 +94,23 @@ def matrix_exp(
     return tf.linalg.expm(x)


-@with_supported_dtypes(
-    {"2.13.0 and below": ("float32", "float64", "complex")}, backend_version
-)
 def eig(
     x: Union[tf.Tensor, tf.Variable],
     /,
     *,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Tuple[tf.Tensor]:
+    if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
+        return tf.linalg.eig(tf.cast(x, tf.float64))
     return tf.linalg.eig(x)


-@with_supported_dtypes(
-    {"2.13.0 and below": ("float32", "float64", "complex")}, backend_version
-)
 def eigvals(
     x: Union[tf.Tensor, tf.Variable],
     /,
 ) -> Union[tf.Tensor, tf.Variable]:
+    if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
+        return tf.linalg.eigvals(tf.cast(x, tf.float64))
     return tf.linalg.eigvals(x)
diff --git a/ivy/functional/backends/tensorflow/experimental/statistical.py b/ivy/functional/backends/tensorflow/experimental/statistical.py
index 2525b76fbee02..7c7d45204ff3f 100644
--- a/ivy/functional/backends/tensorflow/experimental/statistical.py
+++ b/ivy/functional/backends/tensorflow/experimental/statistical.py
@@ -653,7 +653,6 @@ def cov(
     return tf.math.truediv(c, fact)


-@with_supported_dtypes({"2.13.0 and below": ("float64", "int64")}, backend_version)
 def cummax(
     x: Union[tf.Tensor, tf.Variable],
     /,
@@ -664,6 +663,13 @@ def cummax(
     dtype: Optional[tf.DType] = None,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Tuple[tf.Tensor, tf.Tensor]:
+    if x.dtype in (tf.bool, tf.float16):
+        x = tf.cast(x, tf.float64)
+    elif x.dtype in (tf.int16, tf.int8, tf.uint8):
+        x = tf.cast(x, tf.int64)
+    elif x.dtype in (tf.complex128, tf.complex64):
+        x = tf.cast(tf.math.real(x), tf.float64)
+
     if exclusive or reverse:
         if exclusive and reverse:
             x, indices = __find_cummax(
diff --git a/ivy/functional/backends/tensorflow/manipulation.py b/ivy/functional/backends/tensorflow/manipulation.py
index 43c8565e24cf1..7142dc34f1808 100644
--- a/ivy/functional/backends/tensorflow/manipulation.py
+++ b/ivy/functional/backends/tensorflow/manipulation.py
@@ -328,7 +328,7 @@ def swapaxes(
     return tf.transpose(x, config)


-@with_unsupported_dtypes({"2.13.0 and below": ("complex", "bool")}, backend_version)
+@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
 def clip(
     x: Union[tf.Tensor, tf.Variable],
     x_min: Union[Number, tf.Tensor, tf.Variable],
diff --git a/ivy/functional/backends/tensorflow/sorting.py b/ivy/functional/backends/tensorflow/sorting.py
index 982c610e931d8..06403a1d358ca 100644
--- a/ivy/functional/backends/tensorflow/sorting.py
+++ b/ivy/functional/backends/tensorflow/sorting.py
@@ -8,7 +8,7 @@
 from . import backend_version


-@with_unsupported_dtypes({"2.13.0 and below": ("complex", "bool")}, backend_version)
+@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
 def argsort(
     x: Union[tf.Tensor, tf.Variable],
     /,
@@ -20,11 +20,14 @@ def argsort(
 ) -> Union[tf.Tensor, tf.Variable]:
     direction = "DESCENDING" if descending else "ASCENDING"
     x = tf.convert_to_tensor(x)
+    is_bool = x.dtype.is_bool
+    if is_bool:
+        x = tf.cast(x, tf.int32)
     ret = tf.argsort(x, axis=axis, direction=direction, stable=stable)
     return tf.cast(ret, dtype=tf.int64)


-@with_unsupported_dtypes({"2.13.0 and below": ("complex", "bool")}, backend_version)
+@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
 def sort(
     x: Union[tf.Tensor, tf.Variable],
     /,
@@ -38,7 +41,12 @@ def sort(
     # currently it supports only quicksort (unstable)
     direction = "DESCENDING" if descending else "ASCENDING"
     x = tf.convert_to_tensor(x)
+    is_bool = x.dtype.is_bool
+    if is_bool:
+        x = tf.cast(x, tf.int32)
     ret = tf.sort(x, axis=axis, direction=direction)
+    if is_bool:
+        ret = tf.cast(ret, dtype=tf.bool)
     return ret
diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py
index b255e393fe2b3..26b2f58919506 100644
--- a/ivy/functional/backends/torch/elementwise.py
+++ b/ivy/functional/backends/torch/elementwise.py
@@ -8,7 +8,6 @@
 from ivy.func_wrapper import (
     with_unsupported_dtypes,
     handle_numpy_arrays_in_specific_backend,
-    with_supported_dtypes,
 )
 from ivy import promote_types_of_inputs
 from . import backend_version
@@ -54,13 +53,15 @@ def bitwise_xor(
 bitwise_xor.support_native_out = True


-@with_supported_dtypes({"2.0.1 and below": ("complex",)}, backend_version)
 def imag(
     val: torch.Tensor,
     /,
     *,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if val.dtype not in (torch.complex64, torch.complex128):
+        ret = torch.imag(val.to(torch.complex64))
+        return ret.to(val.dtype)
     return torch.imag(val)
diff --git a/ivy/functional/backends/torch/experimental/layers.py b/ivy/functional/backends/torch/experimental/layers.py
index e49de1f5974f9..88f9f97a71efe 100644
--- a/ivy/functional/backends/torch/experimental/layers.py
+++ b/ivy/functional/backends/torch/experimental/layers.py
@@ -563,6 +563,8 @@ def dct(
 ) -> torch.tensor:
     if norm not in (None, "ortho"):
         raise ValueError("Norm must be either None or 'ortho'")
+    if x.dtype not in [torch.float32, torch.float64]:
+        x = x.type(torch.float32)
     if axis < 0:
         axis = axis + len(x.shape)
     if n is not None:
diff --git a/ivy/functional/backends/torch/experimental/statistical.py b/ivy/functional/backends/torch/experimental/statistical.py
index 1cb8fc484bad0..b2aee02bcd414 100644
--- a/ivy/functional/backends/torch/experimental/statistical.py
+++ b/ivy/functional/backends/torch/experimental/statistical.py
@@ -3,7 +3,7 @@
 import torch

 # local
-from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes
 from . import backend_version
 import ivy
 from ..statistical import _infer_dtype
@@ -566,8 +566,8 @@ def cov(
 cov.support_native_out = False


-@with_supported_dtypes(
-    {"2.0.1 and below": ("int64", "float64")},
+@with_unsupported_dtypes(
+    {"2.0.1 and below": ("uint8", "bfloat16", "float16")},
     backend_version,
 )
 def cummax(
     x: torch.Tensor,
     /,
@@ -580,6 +580,13 @@ def cummax(
     dtype: Optional[torch.dtype] = None,
     out: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if x.dtype in (torch.bool, torch.float16):
+        x = x.to(dtype=torch.float64)
+    elif x.dtype in (torch.int16, torch.int8, torch.uint8):
+        x = x.to(dtype=torch.int64)
+    elif x.dtype in (torch.complex64, torch.complex128):
+        x = x.real.to(dtype=torch.float64)
+
     if exclusive or reverse:
         if exclusive and reverse:
             x1, x2 = torch.cummax(torch.flip(x, dims=(axis,)), axis)
diff --git a/ivy/functional/backends/torch/searching.py b/ivy/functional/backends/torch/searching.py
index 473cdf2f087ec..702100b3b5ab1 100644
--- a/ivy/functional/backends/torch/searching.py
+++ b/ivy/functional/backends/torch/searching.py
@@ -6,7 +6,7 @@
 import ivy

-from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes
 from . import backend_version

 # Array API Standard #
@@ -95,7 +95,6 @@ def nonzero(
     return torch.stack(res, dim=1)


-@with_supported_dtypes({"2.0.1 and below and below": ("bool",)}, backend_version)
 def where(
     condition: torch.Tensor,
     x1: Union[float, int, torch.Tensor],
@@ -105,6 +104,8 @@ def where(
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     x1, x2 = ivy.promote_types_of_inputs(x1, x2)
+    if condition.dtype is not torch.bool:
+        condition = condition == 1.0
     return ivy.astype(torch.where(condition, x1, x2), x1.dtype, copy=False)
diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py
index 59c1be8505bb9..1e31b8854cbb2 100644
--- a/ivy/functional/backends/torch/statistical.py
+++ b/ivy/functional/backends/torch/statistical.py
@@ -7,7 +7,7 @@
 # local
 import ivy
 from ivy.functional.ivy.statistical import _get_promoted_type_of_operands
-from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes
 from . import backend_version

 # Array API Standard #
@@ -66,7 +66,6 @@ def max(
 max.support_native_out = True


-@with_supported_dtypes({"2.0.1 and below": ("float", "complex")}, backend_version)
 def mean(
     x: torch.Tensor,
     /,
@@ -83,6 +82,10 @@ def mean(
             return ivy.inplace_update(out, x)
         else:
             return x
+    if "float" not in str(x.dtype) and "complex" not in str(x.dtype):  # unsupported
+        return torch.mean(x.to(torch.float32), dim=axis, keepdim=keepdims, out=out).to(
+            x.dtype
+        )
     return torch.mean(x, dim=axis, keepdim=keepdims, out=out)
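The hunks above all apply one recurring fix: instead of rejecting a dtype through a `with_supported_dtypes` / `with_unsupported_dtypes` decorator, each backend function now casts an unsupported input to a wider native dtype, runs the framework op, and casts the result back. A minimal sketch of that cast-and-restore fallback, assuming the Paddle backend (the helper name `_cast_fallback` and the dtype list are illustrative, not part of this diff):

```python
import paddle

# Dtypes that many Paddle CPU kernels reject (illustrative list, not exhaustive).
_NEEDS_CAST = (paddle.int8, paddle.int16, paddle.uint8, paddle.float16, paddle.bool)


def _cast_fallback(fn, x, *args, **kwargs):
    """Run `fn` on a float32 copy of `x` when its dtype is unsupported,
    then cast the result back to the original dtype."""
    if x.dtype in _NEEDS_CAST:
        return fn(x.cast("float32"), *args, **kwargs).cast(x.dtype)
    return fn(x, *args, **kwargs)


# Hypothetical usage: sorting a bool tensor goes through the float32 fallback.
x = paddle.to_tensor([True, False, True])
print(_cast_fallback(paddle.sort, x, axis=-1))
```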