diff --git a/ivy/functional/backends/paddle/data_type.py b/ivy/functional/backends/paddle/data_type.py index 1a67457c36079..1dfcf49e4104f 100644 --- a/ivy/functional/backends/paddle/data_type.py +++ b/ivy/functional/backends/paddle/data_type.py @@ -196,7 +196,7 @@ def iinfo(type: Union[paddle.dtype, str, paddle.Tensor], /) -> Iinfo: def result_type(*arrays_and_dtypes: Union[paddle.Tensor, paddle.dtype]) -> ivy.Dtype: - return ivy.promote_types(arrays_and_dtypes[0].dtype, arrays_and_dtypes[1].dtype) + return ivy.promote_types_of_inputs(*arrays_and_dtypes)[0].dtype # Extra # diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py index f48d51d3e3487..8fcae08acbf55 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py @@ -74,30 +74,65 @@ def test_torch_bincount( ) +@st.composite +def _elemwise_helper(draw): + value_strategy = st.one_of( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + st.integers(min_value=-10000, max_value=10000), + st.floats(min_value=-10000, max_value=10000), + ) + + dtype_and_x1 = draw(value_strategy) + if isinstance(dtype_and_x1, tuple): + dtype1 = dtype_and_x1[0] + x1 = dtype_and_x1[1][0] + else: + dtype1 = [] + x1 = dtype_and_x1 + + dtype_and_x2 = draw(value_strategy) + if isinstance(dtype_and_x2, tuple): + dtype2 = dtype_and_x2[0] + x2 = dtype_and_x2[1][0] + else: + dtype2 = [] + x2 = dtype_and_x2 + + num_pos_args = None + if not dtype1 and not dtype2: + num_pos_args = 2 + elif not dtype1: + x1, x2 = x2, x1 + input_dtypes = dtype1 + dtype2 + + return x1, x2, input_dtypes, num_pos_args + + @handle_frontend_test( fn_tree="torch.result_type", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - ), + dtypes_and_xs=_elemwise_helper(), test_with_out=st.just(False), ) def test_torch_result_type( - dtype_and_x, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + x1, x2, input_dtypes, num_pos_args = dtypes_and_xs + if num_pos_args is not None: + test_flags.num_positional_args = num_pos_args helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=x[0], - other=x[1], + tensor=x1, + other=x2, )