Skip to content

Commit

Permalink
Add pinv
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jul 3, 2021
1 parent 2b3ca95 commit 1560011
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ kron(a, b)
svd(a, compute_uv=True)
solve(a, b)
inv(a)
pinv(a)
det(a)
logdet(a)
expm(a)
Expand Down
19 changes: 19 additions & 0 deletions lab/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"svd",
"solve",
"inv",
"pinv",
"det",
"logdet",
"expm",
Expand Down Expand Up @@ -173,6 +174,24 @@ def inv(a): # pragma: no cover
"""


@dispatch
def pinv(a): # pragma: no cover
"""Compute the pseudo-inverse of `a`.
Args:
a (tensor): Matrix to compute pseudo-inverse of.
Returns:
tensor: Pseudo-inverse of `a`.
"""
if B.shape(a, -2) >= B.shape(a, -1):
chol = B.chol(B.matmul(a, a, tr_a=True))
return B.cholsolve(chol, B.transpose(a))
else:
chol = B.chol(B.matmul(a, a, tr_b=True))
return B.transpose(B.cholsolve(chol, a))


@dispatch
@abstract()
def det(a): # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = backends
version = 1.3.4
version = 1.3.5
author = Wessel Bruinsma
author_email = [email protected]
description = A generic interface for linear algebra backends
Expand Down
7 changes: 7 additions & 0 deletions tests/test_linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def test_inv(check_lazy_shapes):
check_function(B.inv, (Matrix(4, 3, 3),))


def test_pinv(check_lazy_shapes):
a = Matrix(4, 6, 3).np()
assert B.shape(B.pinv(a)) == (4, 3, 6)
assert B.shape(B.pinv(B.pinv(a))) == (4, 6, 3)
approx(a, B.pinv(B.pinv(a)))


def test_det(check_lazy_shapes):
check_function(B.det, (Matrix(),))
check_function(B.det, (Matrix(4, 3, 3),))
Expand Down

0 comments on commit 1560011

Please sign in to comment.