We don't need Cholesky when we already have the eigendecomposition.
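The identity behind this change, sketched with standalone made-up values (illustrative names only, not part of the diff below): once the metric G has the eigendecomposition G = V diag(lam) V^T, the log-determinant is sum(log lam) and the quadratic form p^T G^-1 p is obtained by rotating the momentum into the eigenbasis and dividing by lam, so the Cholesky solve that cholesky_inverse performed becomes redundant.

import torch

torch.manual_seed(0)
D = 5
A = torch.randn(D, D)
G = A @ A.T + D * torch.eye(D)   # stand-in symmetric positive-definite metric
p = torch.randn(D)               # stand-in momentum vector

# Eigendecomposition path (what the new code relies on):
lam, V = torch.linalg.eigh(G)
log_det_eig = lam.log().sum()
p_rot = V.T @ p
quad_eig = torch.dot(p_rot, p_rot / lam)

# Cholesky path (what cholesky_inverse used to provide):
L = torch.linalg.cholesky(G)
log_det_chol = 2 * torch.log(torch.diagonal(L)).sum()
quad_chol = p @ torch.cholesky_solve(p.view(-1, 1), L).flatten()

print(torch.allclose(log_det_eig, log_det_chol))  # True
print(torch.allclose(quad_eig, quad_chol))        # True

With the SoftAbs metric the eigenvalues are additionally mapped through lam / tanh(softabs_const * lam), but the algebra above is unchanged.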
Si Yu How authored and Si Yu How committed Feb 3, 2023
1 parent bb973da commit 7977222
Showing 1 changed file with 38 additions and 71 deletions.
109 changes: 38 additions & 71 deletions hamiltorch/samplers.py
@@ -64,8 +64,8 @@ def collect_gradients(log_prob, params, pass_grad = None):
     return params


-def fisher(params, log_prob_func=None, jitter=None, normalizing_const=1., softabs_const=1e6, metric=Metric.SOFTABS):
-    """Called upon when using RMHMC. Returns the Fisher Information Matrix or Metric (often referred to as G).
+def fisher_eigenh(params, log_prob_func=None, jitter=None, normalizing_const=1., softabs_const=1e6, metric=Metric.SOFTABS):
+    """Called upon when using RMHMC. Returns the eigendecomposition of Fisher Information Matrix or Metric (often referred to as G).
     Parameters
     ----------
@@ -84,67 +84,42 @@ def fisher(params, log_prob_func=None, jitter=None, normalizing_const=1., softab
     Returns
     -------
-    fish : torch.tensor
-        Fisher Matrix: shape (D,D).
-    abs_eigenvalues : torch.tensor or None
-        Absolute value of the eigenvalues, or None when not using softabs.
+    abs_eigenvalues : torch.tensor
+        Absolute value of the eigenvalues.
+    eigenvectors : torch.tensor or None
+        Eigenvectors of Fisher Matrix: shape (D,D), or None when it is the identity matrix.
     """

     log_prob = log_prob_func(params)
     if util.has_nan_or_inf(log_prob):
         print('Invalid log_prob: {}, params: {}'.format(log_prob, params))
         raise util.LogProbError()

     if metric == Metric.JACOBIAN_DIAG:
-        # raise NotImplementedError()
-        # import pdb; pdb.set_trace()
         jac = util.jacobian(log_prob, params, create_graph=True, return_inputs=False)
         jac = torch.cat([j.flatten() for j in jac])
-        # util.flatten(jac).view(1,-1)
-        fish = torch.matmul(jac.view(-1,1),jac.view(1,-1)).diag().diag()#/ normalizing_const #.diag().diag() / normalizing_const
-    else:
-        hess = torch.autograd.functional.hessian(log_prob_func, params, create_graph=True)
-        fish = - hess #/ normalizing_const
-    if util.has_nan_or_inf(fish):
-        print('Invalid hessian: {}, params: {}'.format(fish, params))
-        raise util.LogProbError()
-    if jitter is not None:
-        params_n_elements = fish.shape[0]
-        fish += (torch.eye(params_n_elements) * torch.rand(params_n_elements) * jitter).to(fish.device)
-    if metric is Metric.JACOBIAN_DIAG:
-        return fish, None
+        if util.has_nan_or_inf(jac):
+            print('Invalid jacobian: {}, params: {}'.format(jac, params))
+        eigenvalues, eigenvectors = -jac * jac, None
+        if jitter is not None:
+            eigenvalues = eigenvalues - jitter * torch.rand_like(eigenvalues)
     elif metric == Metric.SOFTABS:
-        eigenvalues, eigenvectors = torch.linalg.eigh(fish, UPLO='L')
-        abs_eigenvalues = (1./torch.tanh(softabs_const * eigenvalues)) * eigenvalues
-        fish = torch.matmul(eigenvectors, torch.matmul(abs_eigenvalues.diag(), eigenvectors.t()))
-        return fish, abs_eigenvalues
+        hess = torch.autograd.functional.hessian(log_prob_func, params, create_graph=True)
+        if util.has_nan_or_inf(hess):
+            print('Invalid hessian: {}, params: {}'.format(hess, params))
+            raise util.LogProbError()
+        if jitter is not None:
+            hess = hess - jitter * torch.diag(torch.rand(hess.shape[0], device=hess.device))
+        eigenvalues, eigenvectors = torch.linalg.eigh(hess, UPLO='L')
     else:
-        # if metric == Metric.JACOBIAN:
-        #     jac = jacobian(log_prob, params, create_graph=True)
-        #     fish = torch.matmul(jac.t(),jac) / normalizing_const
         raise ValueError('Unknown metric: {}'.format(metric))


-def cholesky_inverse(fish, momentum):
-    """Performs the inverse of a matrix, using the cholesky inverse (with the vector).
-    Parameters
-    ----------
-    fish : torch.tensor
-        Square matrix to be inverted: shape (D,D).
-    momentum : torch.tensor
-        Vector of shape (D,).
-    Returns
-    -------
-    torch.tensor
-        Returns the inverted matrix multiplied by the vector.
-    """
-    lower = torch.linalg.cholesky(fish)
-    y = torch.linalg.solve_triangular(lower, momentum.view(-1, 1), upper=False, unitriangular=False)
-    fish_inv_p = torch.linalg.solve_triangular(lower.t(), y, upper=True, unitriangular=False)
-    return fish_inv_p
+    abs_eigenvalues = eigenvalues / torch.tanh(softabs_const * eigenvalues)
+    if util.has_nan_or_inf(abs_eigenvalues):
+        print('Invalid abs_eigenvalues: {}, params: {}'.format(abs_eigenvalues, params))
+    return abs_eigenvalues, eigenvectors


def gibbs(params, sampler=Sampler.HMC, log_prob_func=None, jitter=None, normalizing_const=1., softabs_const=None, mass=None, metric=Metric.SOFTABS):
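For reference, a small illustration (made-up values, not part of the commit) of the SoftAbs map used in the new return value: lam / tanh(softabs_const * lam) behaves like |lam| away from zero and is floored at roughly 1 / softabs_const near zero, so every returned eigenvalue stays strictly positive.

import torch

softabs_const = 1e6
lam = torch.tensor([-2.0, -1e-8, 1e-8, 3.0])
abs_eigenvalues = lam / torch.tanh(softabs_const * lam)
print(abs_eigenvalues)  # roughly [2.0, 1e-6, 1e-6, 3.0]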
@@ -179,7 +154,12 @@ def gibbs(params, sampler=Sampler.HMC, log_prob_func=None, jitter=None, normaliz
"""

if sampler == Sampler.RMHMC:
dist = torch.distributions.MultivariateNormal(torch.zeros_like(params), fisher(params, log_prob_func, jitter, normalizing_const, softabs_const, metric)[0])
abs_eigenvalues, eigenvectors = fisher_eigenh(params, log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric)
dist = torch.distributions.Normal(torch.zeros_like(params), torch.sqrt(abs_eigenvalues))
v = dist.sample()
if eigenvectors is not None:
v = torch.mv(eigenvectors, v)
return v
elif mass is None:
dist = torch.distributions.Normal(torch.zeros_like(params), torch.ones_like(params))
else:
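Why the rewritten RMHMC branch of gibbs() still draws momentum from N(0, G), sketched with placeholder values (only torch itself is assumed): sampling each coordinate from N(0, lam_i) in the eigenbasis and rotating back by the eigenvectors yields a vector whose covariance is V diag(lam) V^T = G, the same distribution the removed MultivariateNormal(0, G) call sampled from.

import torch

torch.manual_seed(0)
D, n = 4, 200_000
A = torch.randn(D, D)
G = A @ A.T + D * torch.eye(D)    # stand-in metric
lam, V = torch.linalg.eigh(G)

eps = torch.randn(n, D)                           # n draws of eps ~ N(0, I)
v = (lam.sqrt() * eps) @ V.T                      # per-eigenvalue normals, rotated by V
print(torch.allclose(v.T @ v / n, G, atol=0.2))   # empirical covariance is close to G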
@@ -706,24 +686,11 @@ def rm_hamiltonian(params, momentum, log_prob_func, jitter, normalizing_const, s
"""

log_prob = log_prob_func(params)
abs_eigenvalues, eigenvectors = fisher_eigenh(params, log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric)

fish, abs_eigenvalues = fisher(params, log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric)

if abs_eigenvalues is not None:
if util.has_nan_or_inf(fish) or util.has_nan_or_inf(abs_eigenvalues):
print('Invalid Fisher: {} , abs_eigenvalues: {}, params: {}'.format(fish, abs_eigenvalues, params))
raise util.LogProbError()
else:
if util.has_nan_or_inf(fish):
print('Invalid Fisher: {}, params: {}'.format(fish, params))
raise util.LogProbError()

if metric == Metric.SOFTABS:
log_det_abs = abs_eigenvalues.log().sum()
else:
log_det_abs = torch.slogdet(fish)[1]
fish_inverse_momentum = cholesky_inverse(fish, momentum)
quadratic_term = torch.matmul(momentum.view(1, -1), fish_inverse_momentum)
log_det_abs = abs_eigenvalues.log().sum()
rotated_momentum = torch.mv(eigenvectors.T, momentum) if eigenvectors is not None else momentum
quadratic_term = torch.dot(rotated_momentum, rotated_momentum / abs_eigenvalues)
hamiltonian = - log_prob + 0.5 * log_det_abs + 0.5 * quadratic_term
if util.has_nan_or_inf(hamiltonian):
print('Invalid hamiltonian, log_prob: {}, params: {}, momentum: {}'.format(log_prob, params, momentum))
@@ -825,12 +792,12 @@ def hamiltonian(params, momentum, log_prob_func, jitter=0.01, normalizing_const=
         hamiltonian = HA + HB + explicit_binding_const * HC
     elif sampler == Sampler.RMHMC and integrator == Integrator.S3: # CURRENTLY ASSUMING DIAGONAL
         log_prob = log_prob_func(params)
-        fish, abs_eigenvalues = fisher(params, log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric)
-        fish_inverse_momentum = cholesky_inverse(fish, momentum)
-        quadratic_term = torch.matmul(momentum.view(1, -1), fish_inverse_momentum)
-        # print((momentum ** 2 * fish.diag() ** -1).sum() - quadratic_term)
-        hamiltonian = - log_prob + 0.5 * quadratic_term + ham_func(params)
+        abs_eigenvalues, eigenvectors = fisher_eigenh(params, log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric)

+        log_det_abs = abs_eigenvalues.log().sum()
+        rotated_momentum = torch.mv(eigenvectors.T, momentum) if eigenvectors is not None else momentum
+        quadratic_term = torch.dot(rotated_momentum, rotated_momentum / abs_eigenvalues)
+        hamiltonian = - log_prob + 0.5 * log_det_abs + 0.5 * quadratic_term + ham_func(params)
     if util.has_nan_or_inf(hamiltonian):
         print('Invalid hamiltonian, log_prob: {}, params: {}, momentum: {}'.format(log_prob, params, momentum))
         raise util.LogProbError()
