Skip to content

Commit

Permalink
added multi_dot for padddle backend and fixed failing test for torch …
Browse files Browse the repository at this point in the history
…backend
  • Loading branch information
akshatvishu committed Aug 15, 2023
1 parent 0fb73fa commit 59b9f74
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
30 changes: 29 additions & 1 deletion ivy/functional/backends/paddle/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

# local
from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size
from ivy.func_wrapper import with_unsupported_device_and_dtypes
from ivy.func_wrapper import (
with_unsupported_device_and_dtypes,
with_supported_device_and_dtypes,
)
from ivy.utils.exceptions import IvyNotImplementedException
from .. import backend_version

Expand Down Expand Up @@ -117,3 +120,28 @@ def dot(


dot.support_native_out = True


@with_supported_device_and_dtypes(
{
"2.5.1 and below": {
"cpu": (
"float32",
"float64",
),
"gpu": (
"float16",
"float32",
"float64",
),
}
},
backend_version,
)
def multi_dot(
x: paddle.Tensor,
/,
*,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return paddle.linalg.multi_dot(x)
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def adjoint(
return torch.adjoint(x).resolve_conj()


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def multi_dot(
x: Sequence[torch.Tensor],
/,
Expand Down
2 changes: 1 addition & 1 deletion ivy_tests/array_api_testing/test_array_api
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,12 @@ def _generate_multi_dot_dtype_and_arrays(draw):
dtype_x=_generate_multi_dot_dtype_and_arrays(),
test_gradients=st.just(False),
)
def test_multi_dot(dtype_x, test_flags, backend_fw, fn_name):
def test_multi_dot(dtype_x, test_flags, backend_fw, fn_name, on_device):
dtype, x = dtype_x
helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
test_values=True,
Expand Down

0 comments on commit 59b9f74

Please sign in to comment.