Skip to content

Commit

Permalink
Merge pull request #335 from calebweinreb/stable_psd_solve
Browse files Browse the repository at this point in the history
Stable psd solve
  • Loading branch information
slinderman authored Aug 9, 2023
2 parents 4641994 + 44d8c7d commit cadec55
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dynamax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jaxtyping import Array, Int
from scipy.optimize import linear_sum_assignment
from typing import Optional
from jax.scipy.linalg import cho_factor, cho_solve

def has_tpu():
try:
Expand Down Expand Up @@ -198,10 +199,12 @@ def find_permutation(
return perm


def psd_solve(A,b):
def psd_solve(A, b, diagonal_boost=1e-9):
"""A wrapper for coordinating the linalg solvers used in the library for psd matrices."""
A = A + 1e-6
return jnp.linalg.solve(A,b)
A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1])
L, lower = cho_factor(A, lower=True)
x = cho_solve((L, lower), b)
return x

def symmetrize(A):
"""Symmetrize one or more matrices."""
Expand Down

0 comments on commit cadec55

Please sign in to comment.