diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py index fc036b4465474..468c178ec460a 100644 --- a/ivy/functional/backends/paddle/statistical.py +++ b/ivy/functional/backends/paddle/statistical.py @@ -107,7 +107,7 @@ def max( def _calculate_reduced_shape(x, axis, keepdims): if axis is None: axis = tuple(range(len(x.shape))) - elif type(axis) not in (tuple, list): + elif isinstance(axis, int): axis = (axis,) if keepdims: return [1 if i in axis else x.shape[i] for i in range(len(x.shape))] @@ -128,7 +128,7 @@ def mean( ret_dtype = x.dtype if 0 in x.shape: shape = _calculate_reduced_shape(x, axis, keepdims) - ret = paddle.empty(shape) + ret = paddle.full(shape, float("nan")) elif paddle.is_complex(x): ret = paddle.complex( paddle.mean(x.real(), axis=axis, keepdim=keepdims),