From e3c2ac436c71f80883c549da60057932f573cfa0 Mon Sep 17 00:00:00 2001 From: Ababio Date: Sat, 30 Sep 2023 17:08:12 -0400 Subject: [PATCH] feat: add diagonal to paddle frontend --- ivy/functional/frontends/paddle/math.py | 19 +++++++ .../test_frontends/test_paddle/test_math.py | 53 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/ivy/functional/frontends/paddle/math.py b/ivy/functional/frontends/paddle/math.py index 1d5cc060734e2..d8ddf7c2cbc21 100644 --- a/ivy/functional/frontends/paddle/math.py +++ b/ivy/functional/frontends/paddle/math.py @@ -175,6 +175,25 @@ def deg2rad(x, name=None): return ivy.deg2rad(x) +@with_supported_dtypes( + { + "2.5.1 and below": ( + "int32", + "int64", + "float64", + "complex128", + "float32", + "complex64", + "bool", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def diagonal(x, offset=0, axis1=0, axis2=1, name=None): + return ivy.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + + @with_supported_dtypes( {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py index f8cddeca9b1ed..a962154de65cb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py @@ -14,6 +14,33 @@ # --------------- # +@st.composite +def _draw_paddle_diagonal(draw): + _dtype, _x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=10, + min_dim_size=1, + max_dim_size=50, + ) + ) + + offset = (draw(helpers.ints(min_value=-10, max_value=50)),) + axes = ( + draw( + st.lists( + helpers.ints(min_value=-(len(_x)), max_value=len(_x)), + min_size=len(_x) + 1, + max_size=len(_x) + 1, + unique=True, + ).filter(lambda axes: axes[0] % 2 != axes[1] % 2) + ), + ) + + return _dtype, _x[0], offset[0], axes[0] + + @st.composite def _test_paddle_take_helper(draw): mode = draw(st.sampled_from(["raise", "clip", "wrap"])) @@ -680,6 +707,32 @@ def test_paddle_deg2rad( ) +# diagonal +@handle_frontend_test(fn_tree="paddle.diagonal", data=_draw_paddle_diagonal()) +def test_paddle_diagonal( + *, + data, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + _dtype, _x, offset, axes = data + helpers.test_frontend_function( + input_dtypes=_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=_x, + offset=offset, + axis1=axes[0], + axis2=axes[1], + ) + + # diff @handle_frontend_test( fn_tree="paddle.diff",