Autodiff likelihoods from lsbi #43
The other option here is to have analytic gradients (and Hessians) -- I don't know if this would be less flexible/faster or slower?
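A minimal sketch of what the analytic route could look like for the fixed-covariance likelihood, checked against autodiff; the names d, t, m, C, D, theta below are illustrative placeholders, not lsbi attributes:

```python
# Sketch only (not lsbi API): analytic vs autodiff gradient of a
# linear-Gaussian log-likelihood log N(D; m + M @ theta, C) with respect
# to the model matrix M; the closed form is C^{-1} (D - m - M theta) theta^T.
import numpy as np
import jax.numpy as jnp
from jax import grad
from jax.scipy.stats import multivariate_normal

d, t = 4, 2
rng = np.random.default_rng(0)
M = rng.standard_normal((d, t))
theta = rng.standard_normal(t)
m, C = np.zeros(d), np.eye(d)
D = rng.standard_normal(d)

def loglike(M):
    return multivariate_normal.logpdf(D, mean=m + M @ theta, cov=C)

g_autodiff = grad(loglike)(jnp.asarray(M))
g_analytic = np.linalg.solve(C, D - m - M @ theta)[:, None] * theta[None, :]
print(jnp.abs(g_autodiff - g_analytic).max())  # small (float32 rounding)
```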
Good point! Probably better and fits the ethos more. I will say this is not expensive and relatively easy to modify, so until we know what we actually want to optimize/sample, this is probably sufficient. Below is an example fitting a model matrix from a single joint observation:

```python
from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
import anesthetic as ns
import matplotlib.pyplot as plt
d = 100
t = 5
k = 3
C = np.eye(d) * 50
model = LinearModel(M=np.random.randn(d, t))
# model = MixtureModel(M=np.random.randn(k, d, t))
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)
def log_prob(theta_m):
    # evidence (alternative objective, left commented out):
    # mu = model.m + jnp.einsum(
    #     "...ja,...a->...j", theta_m, true_theta * jnp.ones(model.n)
    # )
    # Σ = model._C + jnp.einsum(
    #     "...ja,...ab,...kb->...jk", theta_m, model._Σ, theta_m
    # )
    # return multivariate_normal.logpdf(true_data, mean=mu, cov=Σ)

    # likelihood: negative log N(true_data; m + theta_m @ true_theta, C),
    # returned as a loss to minimize over the model matrix theta_m
    mu = model.m + jnp.einsum(
        "...ja,...a->...j", theta_m, true_theta * jnp.ones(model.n)
    )
    return -multivariate_normal.logpdf(true_data, mean=mu, cov=model._C)
from jax import random
from jax import vmap, value_and_grad, jit
import optax
from jaxopt import LBFGS
rng = random.PRNGKey(0)
theta_m_samples = random.normal(rng, (d, t))
# np_log_prob = model.likelihood(theta_m_samples).logpdf(true_data)
jax_log_prob = log_prob(theta_m_samples)
# value, grad = vmap(value_and_grad(log_prob))(theta_m_samples)
theta_m = random.normal(rng, (d, t))
steps = 1000
# optimizer = optax.adam(1)
# opt_state = optimizer.init(theta_m)
solver = LBFGS(jit(log_prob), maxiter=steps)
# losses = []
# for i in range(steps):
#     value, grad = jit(value_and_grad(log_prob))(theta_m)
#     updates, opt_state = optimizer.update(grad, opt_state)
#     theta_m = optax.apply_updates(theta_m, updates)
#     losses.append(value)
#     print(value)
res = solver.run(theta_m)
surrogate_model = LinearModel(M=res[0])  # res = (params, state); res[0] is the fitted model matrix
a = ns.MCMCSamples(surrogate_model.posterior(true_data).rvs(500)).plot_2d(figsize=(6, 6), label="Fitted Surrogate Posterior")
ns.MCMCSamples(model.posterior(true_data).rvs(500)).plot_2d(a, label="True Posterior")
a.iloc[-1, 0].legend(
    loc="lower center",
    bbox_to_anchor=(len(a) / 2, len(a)),
)
plt.savefig("model_opt.pdf")
```
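A possible follow-up check, continuing the script above (it only uses names already defined there; `res[0]` is the fitted matrix returned by jaxopt's `LBFGS.run`):

```python
# Optional sanity check (sketch): the fit should lower the negative
# log-likelihood relative to the random initialization, and the same
# jitted objective exposes gradients directly for other optimizers.
value_init, _ = jit(value_and_grad(log_prob))(theta_m)
value_fit, grad_fit = jit(value_and_grad(log_prob))(res[0])
print(value_init, value_fit)     # expect value_fit < value_init
print(jnp.abs(grad_fit).max())   # near zero at the optimum
```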
It would be useful for testing numerical inference algorithms to have differentiable likelihoods in lsbi. In theory I think the whole package could swap to jax; however, things like rng handling are quite different and would require some excavation (links to #41).
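For example, numpy draws from stateful generators, whereas jax requires explicit keys that are created, split, and passed around (nothing lsbi-specific in this sketch):

```python
# Sketch of the rng mismatch between numpy and jax.
import numpy as np
from jax import random

rng = np.random.default_rng(0)
x_np = rng.standard_normal(5)        # stateful: the generator advances implicitly

key = random.PRNGKey(0)
key, subkey = random.split(key)      # functional: keys are split and passed explicitly
x_jax = random.normal(subkey, (5,))
```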
The basic thing one needs is the ability to furnish the distributions with a jax log_prob function. The most useful would be the likelihood; this can be done fairly simply, as sketched below.
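A minimal sketch, assuming a `LinearModel` with the attributes used in the comment above (`model.m`, `model.joint()`); the exact lsbi keywords and attributes are assumptions and may need adjusting:

```python
# Sketch (not lsbi API): a jax log-likelihood log p(D | theta) for a linear
# model D ~ N(m + M theta, C), differentiable in theta via jax.grad.
import numpy as np
import jax.numpy as jnp
from jax import grad
from jax.scipy.stats import multivariate_normal
from lsbi.model import LinearModel

d, t = 10, 3
M = np.random.randn(d, t)
C = np.eye(d)
model = LinearModel(M=M, C=C)
theta, data = np.split(model.joint().rvs(), [t], axis=-1)

def log_prob(theta):
    # mean m + M theta; m and C are taken from the same model used to draw the data
    mu = model.m + jnp.einsum("ja,...a->...j", M, theta)
    return multivariate_normal.logpdf(data, mean=mu, cov=C)

score = grad(log_prob)(jnp.asarray(theta))  # d log L / d theta
```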
And similarly for basic mixtures:
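A sketch along the same lines, with uniform weights and synthetic data to stay self-contained (non-trivial weights would add a log-weight term inside the logsumexp):

```python
# Sketch (not lsbi API): jax log-likelihood for a k-component mixture,
# log p(D | theta) = logsumexp_k [ log w_k + log N(D; m + M_k theta, C) ],
# with uniform weights w_k = 1/k for simplicity.
import numpy as np
import jax.numpy as jnp
from jax import grad
from jax.scipy.stats import multivariate_normal
from jax.scipy.special import logsumexp

k, d, t = 3, 10, 2
rng = np.random.default_rng(0)
M = rng.standard_normal((k, d, t))                  # one model matrix per component
C = np.eye(d)
theta_true = rng.standard_normal(t)
data = M[0] @ theta_true + rng.standard_normal(d)   # draw from component 0

def log_prob(theta):
    mu = jnp.einsum("kja,...a->...kj", M, theta)                # component means, shape (k, d)
    logpdfs = multivariate_normal.logpdf(data, mean=mu, cov=C)  # shape (k,)
    return logsumexp(logpdfs, axis=-1) - jnp.log(k)             # uniform mixture weights

score = grad(log_prob)(jnp.asarray(theta_true))
```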
Not sure if this can be elegantly integrated, but I will put this here for now as potentially useful for other projects.
NB: the weighting for mixtures with non-trivial weights is not correct here; to be fixed later.