Add multi-output support to GP Latent (#7471)
* Port 7226 and add dims support

* Fix typo in HSGP prior method

Co-authored-by: hchen19 <[email protected]>

* Fix HSGP test
AlexAndorra authored Aug 30, 2024
1 parent d313012 commit 90f20a2
Showing 4 changed files with 141 additions and 23 deletions.
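For orientation, here is a minimal usage sketch of the feature this commit adds (a hedged example; variable names and shapes are illustrative, not taken from the diff): `n_outputs` lets a single `Latent` GP define several output functions that share the same mean and covariance.

import numpy as np
import pymc as pm

X = np.random.randn(20, 3)  # 20 input points, 3 features

with pm.Model() as model:
    cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
    gp = pm.gp.Latent(cov_func=cov_func)
    # Two GPs sharing one covariance; f has shape (2, 20) instead of (20,)
    f = gp.prior("f", X=X, n_outputs=2)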
74 changes: 61 additions & 13 deletions pymc/gp/gp.py
@@ -148,18 +148,37 @@ class Latent(Base):
def __init__(self, *, mean_func=Zero(), cov_func=Constant(0.0)):
super().__init__(mean_func=mean_func, cov_func=cov_func)

def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
def _build_prior(
self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs
):
mu = self.mean_func(X)
cov = stabilize(self.cov_func(X), jitter)
if reparameterize:
size = np.shape(X)[0]
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
f = pm.Deterministic(name, mu + cholesky(cov).dot(v), dims=kwargs.get("dims", None))
if "dims" in kwargs:
v = pm.Normal(
name + "_rotated_",
mu=0.0,
sigma=1.0,
**kwargs,
)

else:
size = (n_outputs, X.shape[0]) if n_outputs > 1 else X.shape[0]
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)

f = pm.Deterministic(
name,
mu + cholesky(cov).dot(v.T).transpose(),
dims=kwargs.get("dims", None),
)

else:
f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
mu_stack = pt.stack([mu] * n_outputs, axis=0) if n_outputs > 1 else mu
f = pm.MvNormal(name, mu=mu_stack, cov=cov, **kwargs)

return f
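To see what the rotation above does, here is a NumPy sketch of the same whitened (non-centered) parameterization, with illustrative shapes: each row of `v` is white noise for one output GP, and mapping it through the Cholesky factor yields draws from `N(mu, cov)`.

import numpy as np

rng = np.random.default_rng(0)
n, n_outputs = 20, 2
x = np.linspace(0, 1, n)

# A valid covariance matrix: squared-exponential kernel plus jitter
cov = np.exp(-0.5 * np.subtract.outer(x, x) ** 2 / 0.1**2) + 1e-6 * np.eye(n)
L = np.linalg.cholesky(cov)

v = rng.standard_normal((n_outputs, n))  # v ~ N(0, I), one row per output
f = (L @ v.T).T                          # each row is a draw from N(0, cov)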

def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
def prior(self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
R"""
Returns the GP prior distribution evaluated over the input
locations `X`.
@@ -178,6 +197,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
X : array-like
Function input values. If one-dimensional, must be a column
vector with shape `(n, 1)`.
n_outputs : int, default 1
Number of output GPs. If you're using `dims`, make sure their size
is equal to `(n_outputs, X.shape[0])`, i.e. the number of output GPs
by the number of input points.
Example: `gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))`,
where `len(n_gps) == 3` and `len(x_dim) == X.shape[0]`.
reparameterize : bool, default True
Reparameterize the distribution by rotating the random
variable by the Cholesky factor of the covariance matrix.
@@ -188,10 +213,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
Extra keyword arguments that are passed to :class:`~pymc.MvNormal`
distribution constructor.
"""
f = self._build_prior(name, X, n_outputs, reparameterize, jitter, **kwargs)

f = self._build_prior(name, X, reparameterize, jitter, **kwargs)
self.X = X
self.f = f
self.n_outputs = n_outputs

return f
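A sketch of the `dims` usage the docstring describes (the coord names are illustrative): the first dim indexes the output GPs, the second the input points.

import numpy as np
import pymc as pm

X = np.random.randn(20, 3)
coords = {"n_gps": np.arange(3), "x_dim": np.arange(20)}

with pm.Model(coords=coords) as model:
    gp = pm.gp.Latent(cov_func=pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]))
    # dims sizes match (n_outputs, X.shape[0]) as the docstring requires
    f = gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))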

def _get_given_vals(self, given):
@@ -212,12 +239,16 @@ def _get_given_vals(self, given):
def _build_conditional(self, Xnew, X, f, cov_total, mean_total, jitter):
Kxx = cov_total(X)
Kxs = self.cov_func(X, Xnew)

L = cholesky(stabilize(Kxx, jitter))
A = solve_lower(L, Kxs)
v = solve_lower(L, f - mean_total(X))
mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
v = solve_lower(L, (f - mean_total(X)).T)

mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T

Kss = self.cov_func(Xnew)
cov = Kss - pt.dot(pt.transpose(A), A)

return mu, cov
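For reference, the same conditional math in plain NumPy (a sketch with illustrative shapes and a zero mean): the standard GP posterior mean `Ksx Kxx^{-1} f` and covariance `Kss - Ksx Kxx^{-1} Kxs`, computed via solves against `L = chol(Kxx)`, with `f` carrying one row per output GP.

import numpy as np

rng = np.random.default_rng(1)
n, n_star, n_outputs = 20, 5, 2
x, x_star = np.linspace(0, 1, n), np.linspace(0.05, 0.95, n_star)

def k(a, b, ls=0.2):
    return np.exp(-0.5 * np.subtract.outer(a, b) ** 2 / ls**2)

Kxx = k(x, x) + 1e-6 * np.eye(n)
Kxs = k(x, x_star)
f = rng.standard_normal((n_outputs, n))  # one row per output GP

L = np.linalg.cholesky(Kxx)
A = np.linalg.solve(L, Kxs)        # L^{-1} Kxs
v = np.linalg.solve(L, f.T)        # L^{-1} f.T (zero mean assumed)
mu = (A.T @ v).T                   # posterior mean, shape (n_outputs, n_star)
cov = k(x_star, x_star) - A.T @ A  # posterior covariance, shared by outputs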

def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
@@ -255,7 +286,9 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
"""
givens = self._get_given_vals(given)
mu, cov = self._build_conditional(Xnew, *givens, jitter)
return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)

return f


@conditioned_vars(["X", "f", "nu"])
@@ -447,7 +480,15 @@ def _build_marginal_likelihood(self, X, noise_func, jitter):
return mu, stabilize(cov, jitter)

def marginal_likelihood(
self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
self,
name,
X,
y,
sigma=None,
noise=None,
jitter=JITTER_DEFAULT,
is_observed=True,
**kwargs,
):
R"""
Returns the marginal likelihood distribution, given the input
@@ -529,21 +570,28 @@ def _build_conditional(
Kxs = self.cov_func(X, Xnew)
Knx = noise_func(X)
rxx = y - mean_total(X)

L = cholesky(stabilize(Kxx, jitter) + Knx)
A = solve_lower(L, Kxs)
v = solve_lower(L, rxx)
mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
v = solve_lower(L, rxx.T)
mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T

if diag:
Kss = self.cov_func(Xnew, diag=True)
var = Kss - pt.sum(pt.square(A), 0)

if pred_noise:
var += noise_func(Xnew, diag=True)

return mu, var

else:
Kss = self.cov_func(Xnew)
cov = Kss - pt.dot(pt.transpose(A), A)

if pred_noise:
cov += noise_func(Xnew)

return mu, cov if pred_noise else stabilize(cov, jitter)
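A quick NumPy check of why the `diag=True` branch is valid (illustrative shapes): the diagonal of `Kss - A.T @ A` equals `diag(Kss) - sum(A**2, axis=0)`, so the full posterior covariance never needs to be formed just to get per-point variances.

import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((4, 6))
Kss = np.eye(6)

full = Kss - A.T @ A
assert np.allclose(np.diag(full), np.diag(Kss) - np.sum(A**2, axis=0))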

def conditional(
13 changes: 9 additions & 4 deletions pymc/gp/hsgp_approx.py
@@ -442,18 +442,23 @@ def prior(
Dimension name for the GP random variable.
"""
phi, sqrt_psd = self.prior_linearized(X)
self._sqrt_psd = sqrt_psd

if self._parametrization == "noncentered":
self._beta = pm.Normal(
f"{name}_hsgp_coeffs_",
size=self._m_star - int(self._drop_first),
f"{name}_hsgp_coeffs",
size=self.n_basis_vectors - int(self._drop_first),
dims=hsgp_coeffs_dims,
)
self._sqrt_psd = sqrt_psd
f = self.mean_func(X) + phi @ (self._beta * self._sqrt_psd)

elif self._parametrization == "centered":
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", sigma=sqrt_psd, dims=hsgp_coeffs_dims)
self._beta = pm.Normal(
f"{name}_hsgp_coeffs",
sigma=sqrt_psd,
size=self.n_basis_vectors - int(self._drop_first),
dims=hsgp_coeffs_dims,
)
f = self.mean_func(X) + phi @ self._beta

self.f = pm.Deterministic(name, f, dims=gp_dims)
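For reference, a NumPy sketch of the two HSGP parametrizations touched here (all arrays are illustrative stand-ins): the noncentered form samples standardized coefficients and rescales them by the square-root spectral densities, while the centered form puts that scale directly in the coefficient prior; both induce the same distribution over `f`.

import numpy as np

rng = np.random.default_rng(3)
n, m = 50, 10
phi = rng.standard_normal((n, m))          # stand-in for the HSGP basis
sqrt_psd = np.abs(rng.standard_normal(m))  # stand-in for sqrt spectral densities

# Noncentered: beta_raw ~ N(0, 1), scale applied afterwards
beta_raw = rng.standard_normal(m)
f_noncentered = phi @ (beta_raw * sqrt_psd)

# Centered: beta ~ N(0, sqrt_psd), scale inside the prior
beta = sqrt_psd * rng.standard_normal(m)
f_centered = phi @ beta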
75 changes: 70 additions & 5 deletions tests/gp/test_gp.py
@@ -17,6 +17,7 @@

import numpy as np
import numpy.testing as npt
import pytensor.tensor as pt
import pytest

import pymc as pm
@@ -90,7 +91,12 @@ def test_raise_value_error(self):
with self.model:
with pytest.raises(ValueError):
self.gp.marginal_likelihood(
"like_both", X=self.x, Xu=self.xu, y=self.y, noise=self.sigma, sigma=self.sigma
"like_both",
X=self.x,
Xu=self.xu,
y=self.y,
noise=self.sigma,
sigma=self.sigma,
)

with pytest.raises(ValueError):
@@ -177,7 +183,11 @@ def setup_method(self):
pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
)
self.means = (pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5))
self.means = (
pm.gp.mean.Constant(0.5),
pm.gp.mean.Constant(0.5),
pm.gp.mean.Constant(0.5),
)

def testAdditiveMarginal(self):
with pm.Model() as model1:
@@ -199,7 +209,9 @@ def testAdditiveMarginal(self):

with model1:
fp1 = gpsum.conditional(
"fp1", self.Xnew, given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum}
"fp1",
self.Xnew,
given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum},
)
with model2:
fp2 = gptot.conditional("fp2", self.Xnew)
@@ -230,7 +242,9 @@ def testAdditiveMarginalApprox(self, approx):

with pm.Model() as model2:
gptot = pm.gp.MarginalApprox(
mean_func=reduce(add, self.means), cov_func=reduce(add, self.covs), approx=approx
mean_func=reduce(add, self.means),
cov_func=reduce(add, self.covs),
approx=approx,
)
fsum = gptot.marginal_likelihood("f", self.X, Xu, self.y, sigma=sigma)
model2_logp = model2.compile_logp()({})
@@ -352,6 +366,53 @@ def testLatent2(self):
latent_logp = model.compile_logp()({"f_rotated_": y_rotated, "p": self.pnew})
npt.assert_allclose(latent_logp, self.logp, atol=5)

def testLatentMultioutput(self):
n_outputs = 2
X = np.random.randn(20, 3)
y = np.random.randn(n_outputs, 20)
Xnew = np.random.randn(30, 3)
pnew = np.random.randn(n_outputs, 30)

with pm.Model() as latent_model:
cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
mean_func = pm.gp.mean.Constant(0.5)
latent_gp = pm.gp.Latent(mean_func=mean_func, cov_func=cov_func)
latent_f = latent_gp.prior("f", X, n_outputs=n_outputs, reparameterize=True)
latent_p = latent_gp.conditional("p", Xnew)

with pm.Model() as marginal_model:
cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
mean_func = pm.gp.mean.Constant(0.5)
marginal_gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
marginal_f = marginal_gp.marginal_likelihood("f", X, y, sigma=0.0)
marginal_p = marginal_gp.conditional("p", Xnew)

assert tuple(latent_f.shape.eval()) == tuple(marginal_f.shape.eval()) == y.shape
assert tuple(latent_p.shape.eval()) == tuple(marginal_p.shape.eval()) == pnew.shape

chol = np.linalg.cholesky(cov_func(X).eval())
v = np.linalg.solve(chol, (y - 0.5).T)
A = np.linalg.solve(chol, cov_func(X, Xnew).eval()).T
mu_cond = mean_func(Xnew).eval() + (A @ v).T
cov_cond = cov_func(Xnew, Xnew).eval() - A @ A.T

with pm.Model() as numpy_model:
numpy_p = pm.MvNormal.dist(mu=pt.as_tensor(mu_cond), cov=pt.as_tensor(cov_cond))

latent_rv_logp = pm.logp(latent_p, pnew)
marginal_rv_logp = pm.logp(marginal_p, pnew)
numpy_rv_logp = pm.logp(numpy_p, pnew)

assert (
latent_rv_logp.shape.eval()
== marginal_rv_logp.shape.eval()
== numpy_rv_logp.shape.eval()
)

npt.assert_allclose(latent_rv_logp.eval(), marginal_rv_logp.eval(), atol=5)
npt.assert_allclose(latent_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)
npt.assert_allclose(marginal_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)


class TestTP:
R"""
@@ -486,7 +547,11 @@ def setup_method(self):
self.X = cartesian(*self.Xs)
self.N = np.prod([len(X) for X in self.Xs])
self.y = np.random.randn(self.N) * 0.1
self.Xnews = (np.random.randn(5, 1), np.random.randn(5, 1), np.random.randn(5, 1))
self.Xnews = (
np.random.randn(5, 1),
np.random.randn(5, 1),
np.random.randn(5, 1),
)
self.Xnew = np.concatenate(self.Xnews, axis=1)
self.sigma = 0.2
self.pnew = np.random.randn(len(self.Xnew))
2 changes: 1 addition & 1 deletion tests/gp/test_hsgp_approx.py
@@ -186,7 +186,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
gp = pm.gp.HSGP(m=[n_basis], c=4.0, cov_func=cov_func, drop_first=drop_first)
gp.prior("f1", X1)

n_coeffs = model.f1_hsgp_coeffs_.type.shape[0]
n_coeffs = model.f1_hsgp_coeffs.type.shape[0]
if drop_first:
assert (
n_coeffs == n_basis - 1