fix: fix issues with tile function in Ivy Functional API #28402

Open · wants to merge 15 commits into base: main

Changes from all commits
19 changes: 12 additions & 7 deletions ivy/functional/backends/paddle/data_type.py
@@ -118,13 +118,18 @@ def astype(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = ivy.as_native_dtype(dtype)
-
-    if copy and 0 in x.shape:
-        return paddle.empty(x.shape, dtype=dtype)
-
-    if x.dtype == dtype:
-        return x.clone() if copy else x
-    return x.clone().cast(dtype) if copy else x.cast(dtype)
+    if copy:
+        # Checking if the tensor is not empty
+        # As clone is not supported for empty tensors
+        if 0 in x.shape:
+            return paddle.to_tensor(
+                x,
+                dtype=dtype,
+                place=x.place,
+                stop_gradient=x.stop_gradient,
+            )
+        return x.clone() if x.dtype == dtype else x.clone().cast(dtype)
+    return x if x.dtype == dtype else x.cast(dtype)


 def broadcast_arrays(*arrays: paddle.Tensor) -> List[paddle.Tensor]:
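
Note: the copy branch avoids `clone()` because Paddle reportedly cannot clone zero-sized tensors, rebuilding them via `paddle.to_tensor` instead. A minimal sketch of the behaviour relied on, assuming a Paddle build with zero-sized tensor support (shapes and dtypes here are illustrative only):

    import paddle

    x = paddle.empty([0, 3], dtype="float32")
    # rebuild instead of clone(), preserving placement and gradient flags
    y = paddle.to_tensor(x, dtype="float64", place=x.place, stop_gradient=x.stop_gradient)
    assert tuple(y.shape) == (0, 3)
    assert y.dtype == paddle.float64
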
64 changes: 34 additions & 30 deletions ivy/functional/backends/paddle/manipulation.py
@@ -363,38 +363,42 @@ def repeat(
 def tile(
     x: paddle.Tensor, /, repeats: Sequence[int], *, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
-    if x.ndim >= 7:
-        repeats = (
-            repeats.numpy().tolist() if isinstance(repeats, paddle.Tensor) else repeats
-        )
-        new_shape = [*x.shape[:5], -1]
-        reshaped_tensor = paddle.reshape(x, new_shape)
-        new_repeats = repeats[:5] + [math.prod(repeats[5:])]
-        tiled_reshaped_tensor = tile(reshaped_tensor, new_repeats)
-        tiled_shape = tuple(s * r for s, r in zip(x.shape, repeats))
-        result = paddle.reshape(tiled_reshaped_tensor, tiled_shape)
-        return result
-    if ivy.min(repeats) == 0:
-        # This logic is to mimic other backends behaviour when a 0 in repeat
-        # is received since paddle doesn't natively support it
-        if len(repeats) < x.ndim:
+    repeats = repeats.tolist() if isinstance(repeats, paddle.Tensor) else list(repeats)
+    # Paddle doesn't natively support repeats containing zeros
+    if 0 in x.shape or (len(repeats) > 0 and min(repeats) == 0):
+        if x.ndim == 0:
+            shape = repeats
+        elif len(repeats) <= x.ndim:
             shape = x.shape
-            shape[-len(repeats) :] = paddle_backend.multiply(
-                shape[-len(repeats) :], repeats
-            ).tolist()
-        elif len(repeats) > x.ndim:
-            shape = (
-                repeats.tolist()
-                if isinstance(repeats, paddle.Tensor)
-                else list(repeats)
-            )
-            shape[-x.ndim - 1 :] = paddle_backend.multiply(
-                shape[-x.ndim - 1 :], repeats
-            ).tolist()
+            shape[-len(repeats) :] = [
+                s * r for s, r in zip(shape[-len(repeats) :], repeats)
+            ]
         else:
-            shape = paddle_backend.multiply(x.shape, repeats).tolist()
-        return paddle.zeros(shape).cast(x.dtype)
-
+            shape = repeats.copy()
+            shape[-x.ndim :] = [s * r for r, s in zip(shape[-x.ndim :], x.shape)]
+        return paddle.empty(shape, dtype=x.dtype)
+    # Paddle doesn't natively support tensors containing more than 6 dimensions
+    if x.ndim > 6 or len(repeats) > 6:
+        if len(repeats) < x.ndim:
+            repeats = [1] * (x.ndim - len(repeats)) + repeats
+        elif len(repeats) > x.ndim:
+            shape = [1] * (len(repeats) - x.ndim) + x.shape
+            x = paddle.reshape(x, shape)
+        cur_shape = x.shape
+        cur_tensor = x
+        for i in range(0, x.ndim, 5):
+            size = 5 if i <= x.ndim - 5 else x.ndim - i
+            red_shape = [*cur_shape[:size], -1]
+            red_tensor = paddle.reshape(cur_tensor, red_shape)
+            red_repeats = [*repeats[i : i + size], 1]
+            tiled_red_tensor = paddle.tile(red_tensor, red_repeats)
+            perm = [size, *list(range(size))]
+            tiled_red_tensor = paddle.transpose(tiled_red_tensor, perm)
+            cur_shape = cur_shape[size:] + [
+                s * r for s, r in zip(cur_shape[:size], repeats[i : i + size])
+            ]
+            cur_tensor = paddle.reshape(tiled_red_tensor, cur_shape)
+        return cur_tensor
     return paddle.tile(x, repeats)
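
Note: both workarounds target gaps in `paddle.tile` relative to NumPy semantics, repeats containing zeros and ranks above 6. A NumPy reference sketch of the behaviour the backend now emulates (shapes chosen only for illustration):

    import numpy as np

    # a zero anywhere in repeats yields an empty result rather than an error
    assert np.tile(np.ones((2, 3)), (0, 2)).shape == (0, 6)

    # ranks above 6 also work in NumPy; Paddle's native tile stops at 6 dims,
    # hence the chunked reshape/tile/transpose loop above
    assert np.tile(np.ones((1,) * 7), (2,) * 7).shape == (2,) * 7
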
15 changes: 14 additions & 1 deletion ivy/functional/backends/paddle/statistical.py
@@ -160,6 +160,16 @@ def axis_condition(axis):
     return ret.astype(ret_dtype)


+def _calculate_reduced_shape(x, axis, keepdims):
+    if axis is None:
+        axis = tuple(range(len(x.shape)))
+    elif isinstance(axis, int):
+        axis = (axis,)
+    if keepdims:
+        return [1 if i in axis else x.shape[i] for i in range(len(x.shape))]
+    return [x.shape[i] for i in range(len(x.shape)) if i not in axis]
+
+
 @with_supported_dtypes(
     {"2.6.0 and below": ("bool", "complex", "float32", "float64")}, backend_version
 )
@@ -174,7 +184,10 @@ def mean(
 ) -> paddle.Tensor:
     if dtype is not None:
         x = ivy.astype(x, dtype).to_native()
-    if paddle.is_complex(x):
+    if 0 in x.shape:
+        shape = _calculate_reduced_shape(x, axis, keepdims)
+        ret = paddle.full(shape, float("nan"))
+    elif paddle.is_complex(x):
         ret = paddle.complex(
             paddle.mean(x.real(), axis=axis, keepdim=keepdims),
             paddle.mean(x.imag(), axis=axis, keepdim=keepdims),
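
Note: `_calculate_reduced_shape` mirrors the output shape NumPy produces when reducing, and the NaN fill matches NumPy's mean-of-empty-slice result. A reference sketch:

    import warnings

    import numpy as np

    x = np.zeros((0, 4))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)  # mean of empty slice
        out = np.mean(x, axis=0)
    # reducing over the zero-sized axis leaves shape (4,), filled with NaN
    assert out.shape == (4,)
    assert np.all(np.isnan(out))
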
25 changes: 11 additions & 14 deletions ivy/functional/backends/tensorflow/manipulation.py
@@ -312,22 +312,19 @@ def tile(
     *,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
-    if x.shape == ():
-        x = tf.reshape(x, (-1,))
-    if isinstance(repeats, Number):
-        repeats = [repeats]
-    if isinstance(repeats, tf.Tensor) and repeats.shape == ():
-        repeats = tf.reshape(repeats, (-1,))
-    # code to unify behaviour with numpy and torch
-    if len(x.shape) < len(repeats):
-        while len(x.shape) != len(repeats):
-            x = tf.expand_dims(x, 0)
-    elif len(x.shape) > len(repeats):
-        repeats = list(repeats)
-        while len(x.shape) != len(repeats):
-            repeats = [1] + repeats
+    # Unify behaviour with numpy and torch
+    # TODO remove the unifying behaviour code if tensorflow handles this
+    # https://github.com/tensorflow/tensorflow/issues/58002
+    if len(repeats) < len(x.shape):
+        repeats = (
+            repeats.numpy().tolist()
+            if isinstance(repeats, (tf.Tensor, tf.Variable))
+            else list(repeats)
+        )
+        repeats = [1] * (len(x.shape) - len(repeats)) + repeats
+    elif len(repeats) > len(x.shape):
+        shape = [1] * (len(repeats) - len(x.shape)) + x.shape
+        x = tf.reshape(x, shape)
     return tf.tile(x, repeats)
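
Note: the removed while-loops and the new padding both implement NumPy's rank-broadcasting rule for tile: the shorter of shape and repeats is left-padded with ones. A NumPy sketch of that rule:

    import numpy as np

    assert np.tile(np.ones((2, 3)), (2,)).shape == (2, 6)     # repeats promoted to (1, 2)
    assert np.tile(np.ones(3), (2, 1, 2)).shape == (2, 1, 6)  # x promoted to (1, 1, 3)
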
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/manipulation.py
@@ -266,7 +266,7 @@ def tile(
 ) -> torch.Tensor:
     if isinstance(repeats, torch.Tensor):
         repeats = repeats.detach().cpu().numpy().tolist()
-    return x.repeat(repeats)
+    return torch.tile(x, repeats)


 def constant_pad(
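
Note: the switch matters when repeats is shorter than the tensor's rank: `Tensor.repeat` requires at least as many entries as the tensor has dimensions, while `torch.tile` left-pads like NumPy. A minimal sketch:

    import torch

    x = torch.ones(2, 3)
    # x.repeat((2,)) would raise a RuntimeError here
    assert torch.tile(x, (2,)).shape == (2, 6)
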
3 changes: 3 additions & 0 deletions ivy/functional/ivy/creation.py
@@ -128,6 +128,9 @@ def _remove_np_bfloat16(obj):
     # from numpy arrays that have bfloat16 dtype using any extension because
     # bfloat16 in not supported natively by numpy (as of version <=1.25)
     if isinstance(obj, np.ndarray) and obj.dtype.name == "bfloat16":
+        # change dtype of empty array instead so that it doesn't lose its shape
+        if 0 in obj.shape:
+            return obj.astype(np.float32)
         return obj.tolist()
     return obj
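
Note: a dtype cast is used for empty arrays instead of `tolist()` because round-tripping a zero-sized array through nested lists drops the dimensions after the first zero. A NumPy sketch:

    import numpy as np

    x = np.empty((2, 0, 3))
    assert np.array(x.tolist()).shape == (2, 0)     # trailing dim lost via tolist()
    assert x.astype(np.float32).shape == (2, 0, 3)  # the cast keeps the full shape
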
2 changes: 1 addition & 1 deletion ivy/functional/ivy/experimental/manipulation.py
@@ -2203,7 +2203,7 @@ def fill_diagonal(
     steps = ivy.arange(0, end, step)
     if isinstance(v, (ivy.Array, ivy.NativeArray)):
         v = ivy.reshape(v, (-1,)).astype(a.dtype)
-        v = ivy.tile(v, int(ivy.ceil(len(steps) / v.shape[0])))[: len(steps)]
+        v = ivy.tile(v, (int(ivy.ceil(len(steps) / v.shape[0])),))[: len(steps)]
     else:
         v = ivy.repeat(v, len(steps))
     ivy.scatter_flat(steps, v, size=a.shape[0], reduction="replace", out=a)
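
Note: the only change is wrapping the scalar repeat count in a 1-tuple so `ivy.tile` always receives a sequence; for a 1-D `v` the result is identical, as NumPy's own tile illustrates:

    import numpy as np

    v = np.arange(2)
    assert np.tile(v, 3).tolist() == np.tile(v, (3,)).tolist()  # [0, 1, 0, 1, 0, 1]
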
57 changes: 22 additions & 35 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -335,28 +335,16 @@ def _masked_fill_helper(draw):
 @st.composite
 def _repeat_helper(draw):
     shape = draw(
-        helpers.get_shape(
-            min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10
-        )
+        st.shared(helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape")
     )

-    input_dtype, x = draw(
-        helpers.dtype_and_values(
-            available_dtypes=helpers.get_dtypes("valid"),
-            shape=shape,
-        )
-    )
-
-    MAX_NUMPY_DIMS = 32
     repeats = draw(
         st.lists(
-            st.integers(min_value=1, max_value=5),
+            st.integers(min_value=0, max_value=5),
             min_size=len(shape),
-            max_size=MAX_NUMPY_DIMS,
+            max_size=8,
         )
     )
-    assume(np.prod(repeats) * np.prod(shape) <= 2**28)
-    return input_dtype, x, repeats
+    return repeats


 @st.composite
@@ -11332,41 +11320,40 @@ def test_torch_remainder_(
     class_tree=CLASS_TREE,
     init_tree="torch.tensor",
     method_name="repeat",
-    dtype_x_repeats=_repeat_helper(),
-    unpack_repeat=st.booleans(),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        shape=st.shared(
+            helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape"
+        ),
+    ),
+    repeats=_repeat_helper(),
+    unpack_repeats=st.booleans(),
 )
 def test_torch_repeat(
-    dtype_x_repeats,
-    unpack_repeat,
+    dtype_and_x,
+    repeats,
+    unpack_repeats,
     frontend_method_data,
     init_flags,
     method_flags,
     frontend,
     on_device,
     backend_fw,
 ):
-    input_dtype, x, repeats = dtype_x_repeats
-
-    if backend_fw == "paddle":
-        # paddle only supports size of the shape of repeats
-        # to be less than or equal to 6
-        assume(len(repeats) <= 6)
-
-    repeat = {
-        "repeats": repeats,
-    }
-    if unpack_repeat:
-        method_flags.num_positional_args = len(repeat["repeats"]) + 1
-        for i, x_ in enumerate(repeat["repeats"]):
-            repeat[f"x{i}"] = x_
+    input_dtype, x = dtype_and_x
+    if unpack_repeats and len(repeats) > 0:
+        method_flags.num_positional_args = len(repeats)
+        method_kwargs = {f"x{i}": x_ for i, x_ in enumerate(repeats)}
+    else:
+        method_kwargs = {"repeats": repeats}
     helpers.test_frontend_method(
         init_input_dtypes=input_dtype,
         backend_to_test=backend_fw,
         init_all_as_kwargs_np={
             "data": x[0],
         },
         method_input_dtypes=input_dtype,
-        method_all_as_kwargs_np=repeat,
+        method_all_as_kwargs_np=method_kwargs,
         frontend_method_data=frontend_method_data,
         init_flags=init_flags,
         method_flags=method_flags,
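
Note: the helper and the test now coordinate through `st.shared`: every strategy drawn with the same key yields the same value within one example, which lets `repeats` be sized against the value's shape. A self-contained Hypothesis sketch (names are illustrative):

    from hypothesis import given, strategies as st

    shape = st.shared(st.lists(st.integers(0, 5), max_size=8), key="value_shape")

    @given(a=shape, b=shape)
    def check_same_draw(a, b):
        assert a == b  # both come from a single shared draw

    check_same_draw()
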
22 changes: 9 additions & 13 deletions ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py
@@ -701,29 +701,25 @@ def test_swapaxes(
 @handle_test(
     fn_tree="functional.ivy.tile",
     dtype_value=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("valid", full=True),
-        shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
-    ),
-    repeat=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("signed_integer"),
-        shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map(
-            lambda rep: (len(rep),)
+        available_dtypes=helpers.get_dtypes("valid"),
+        shape=st.shared(
+            helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape"
         ),
-        min_value=0,
-        max_value=10,
     ),
+    repeats=st.lists(st.integers(min_value=0, max_value=5), max_size=8),
 )
-def test_tile(*, dtype_value, repeat, test_flags, backend_fw, fn_name, on_device):
+def test_tile(*, dtype_value, repeats, test_flags, backend_fw, fn_name, on_device):
     dtype, value = dtype_value
-    repeat_dtype, repeat_list = repeat
+    # Empty tensors do not copy correctly in paddle
+    assume(backend_fw != "paddle" or 0 not in value[0].shape)
     helpers.test_function(
-        input_dtypes=dtype + repeat_dtype,
+        input_dtypes=dtype,
         test_flags=test_flags,
         backend_to_test=backend_fw,
         fn_name=fn_name,
         on_device=on_device,
         x=value[0],
-        repeats=repeat_list[0],
+        repeats=repeats,
         rtol_=1e-2,
         atol_=1e-2,
         xs_grad_idxs=[[0, 0]],