From 3882b0aefa395bc5ef355d6d7b166cb5cfd26626 Mon Sep 17 00:00:00 2001 From: vaibhav Date: Thu, 14 Sep 2023 01:24:07 +0530 Subject: [PATCH] adding roots function for jax frontend --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 +- .../test_jax/test_numpy/test_mathematical_functions.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 062ebfcff028d..29420acd4280f 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -669,7 +669,7 @@ def roots(p): # finding the eigenvalue of comp_matrix which is also the roots roots = ivy.eigvals(comp_matrix) - return roots + return ivy.flip(roots) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index d6b80dbf7731a..fca0a3f957a2e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2826,7 +2826,7 @@ def test_jax_remainder( num_arrays=1, min_num_dims=1, max_num_dims=1, - min_dim_size=0, + min_dim_size=3, ), ) def test_jax_roots( @@ -2839,7 +2839,6 @@ def test_jax_roots( backend_fw, ): input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, test_flags=test_flags, @@ -2847,7 +2846,7 @@ def test_jax_roots( backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - p=x, + p=x[0], atol=1e-05, rtol=1e-03, )