Skip to content

Commit

Permalink
fix: frontend torch linalg cholesky_ex (#28773)
Browse files Browse the repository at this point in the history
Co-authored-by: Jin Wang <[email protected]>
  • Loading branch information
Daniel4078 and Jin Wang authored Jun 27, 2024
1 parent 5665437 commit 5173264
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
9 changes: 7 additions & 2 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ def cholesky(input, *, upper=False, out=None):


@to_ivy_arrays_and_back
@with_supported_dtypes(
{"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def cholesky_ex(input, *, upper=False, check_errors=False, out=None):
try:
results = namedtuple("cholesky_ex", ['L', 'info'])
matrix = ivy.cholesky(input, upper=upper, out=out)
info = ivy.zeros(input.shape[:-2], dtype=ivy.int32)
return matrix, info
return results(matrix, info)
except RuntimeError as e:
if check_errors:
raise RuntimeError(e) from e
else:
results = namedtuple("cholesky_ex", ['L', 'info'])
matrix = input * math.nan
info = ivy.ones(input.shape[:-2], dtype=ivy.int32)
return matrix, info
return results(matrix, info)


@to_ivy_arrays_and_back
Expand Down
7 changes: 4 additions & 3 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def test_torch_cholesky(

@handle_frontend_test(
fn_tree="torch.linalg.cholesky_ex",
dtype_and_x=_get_dtype_and_matrix(square=True, batch=True),
dtype_and_x=_get_dtype_and_matrix(square=True),
upper=st.booleans(),
)
def test_torch_cholesky_ex(
Expand All @@ -350,8 +350,9 @@ def test_torch_cholesky_ex(
backend_fw,
):
dtype, x = dtype_and_x
x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite

x = np.asarray(x[0], dtype=dtype[0])
x = np.matmul(np.conjugate(x.T), x) + np.identity(x.shape[0], dtype=dtype[0])
# make symmetric positive-definite
helpers.test_frontend_function(
input_dtypes=dtype,
backend_to_test=backend_fw,
Expand Down

0 comments on commit 5173264

Please sign in to comment.