diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 2fc2e29088e22..2cf190bd2ac3d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -335,26 +335,16 @@ def _masked_fill_helper(draw): @st.composite def _repeat_helper(draw): shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - - input_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - ) + st.shared(helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape") ) - repeats = draw( st.lists( - st.integers(min_value=1, max_value=5), + st.integers(min_value=0, max_value=5), min_size=len(shape), - max_size=5, + max_size=8, ) ) - return input_dtype, x, repeats + return repeats @st.composite @@ -11330,11 +11320,18 @@ def test_torch_remainder_( class_tree=CLASS_TREE, init_tree="torch.tensor", method_name="repeat", - dtype_x_repeats=_repeat_helper(), + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared( + helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape" + ), + ), + repeats=_repeat_helper(), unpack_repeats=st.booleans(), ) def test_torch_repeat( - dtype_x_repeats, + dtype_and_x, + repeats, unpack_repeats, frontend_method_data, init_flags, @@ -11343,8 +11340,8 @@ def test_torch_repeat( on_device, backend_fw, ): - input_dtype, x, repeats = dtype_x_repeats - if unpack_repeats: + input_dtype, x = dtype_and_x + if unpack_repeats and len(repeats) > 0: method_flags.num_positional_args = len(repeats) method_kwargs = {f"x{i}": x_ for i, x_ in enumerate(repeats)} else: