Marginalapprox fix (#6076)
* switch from using DensityDist to using Potential (see the sketch below the change summary)

* increase tolerance on flaky tests, add test using find_MAP

* refactor MarginalApprox tests

* run precommit

* address comments, pass approx arg correctly, improve docstrings

* fix comments, pass jitter through correctly, remove the is_observed arg
bwengals authored Sep 1, 2022
1 parent 0b191ad commit 8f02bea
Showing 2 changed files with 60 additions and 92 deletions.
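
The core change, per the first bullet above: instead of wrapping the approximate log-likelihood in a pm.DensityDist with y observed, the GP now evaluates it once and registers it with pm.Potential, which adds an arbitrary term to the model's joint log-probability. A minimal runnable sketch of the pattern, illustrative only and not the PyMC source; loglik_fn is a hypothetical stand-in for _build_marginal_likelihood_loglik:

    import numpy as np
    import pymc as pm

    y = np.random.normal(size=10)  # toy data

    def loglik_fn(y, sigma):
        # hypothetical stand-in for the real approximate log-likelihood builder
        return -0.5 * ((y / sigma) ** 2).sum()

    with pm.Model():
        sigma = pm.HalfNormal("sigma")
        # The Potential term is summed into the model logp, so MAP and MCMC
        # see the approximate likelihood without a random variable for y.
        pm.Potential("marginalapprox_loglik_f", loglik_fn(y, sigma))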
pymc/gp/gp.py (48 changes: 4 additions & 44 deletions)
@@ -685,18 +685,13 @@ def __init__(self, approx="VFE", *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
     def __add__(self, other):
-        # new_gp will default to FITC approx
         new_gp = super().__add__(other)
-        # make sure new gp has correct approx
         if not self.approx == other.approx:
             raise TypeError("Cannot add GPs with different approximations")
         new_gp.approx = self.approx
         return new_gp
 
-    # Use y as first argument, so that we can use functools.partial
-    # in marginal_likelihood instead of lambda. This makes pickling
-    # possible.
-    def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
+    def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
         sigma2 = at.square(sigma)
         Kuu = self.cov_func(Xu)
         Kuf = self.cov_func(Xu, X)
@@ -725,9 +720,7 @@ def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
         quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
         return -1.0 * (constant + logdet + quadratic + trace)
 
-    def marginal_likelihood(
-        self, name, X, Xu, y, noise=None, is_observed=True, jitter=JITTER_DEFAULT, **kwargs
-    ):
+    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the approximate marginal likelihood distribution, given the input
         locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -747,9 +740,6 @@
             noise. Must have shape `(n, )`.
         noise: scalar, Variable
             Standard deviation of the Gaussian noise.
-        is_observed: bool
-            Whether to set `y` as an `observed` variable in the `model`.
-            Default is `True`.
         jitter: scalar
             A small correction added to the diagonal of positive semi-definite
             covariance matrices to ensure numerical stability.
@@ -767,38 +757,8 @@ def marginal_likelihood(
         else:
             self.sigma = noise
 
-        if is_observed:
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                size=X.shape[0],
-                **kwargs,
-            )
-        else:
-            warnings.warn(
-                "The 'is_observed' argument has been deprecated. If the GP is "
-                "unobserved use gp.Latent instead.",
-                FutureWarning,
-            )
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                # ndim_supp=1,
-                size=X.shape[0],
-                **kwargs,
-            )
+        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)
 
     def _build_conditional(
         self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter
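
With the DensityDist path gone, marginal_likelihood no longer returns a random variable and no longer accepts is_observed; unobserved GPs should use gp.Latent instead, as the removed deprecation warning advised. A hedged usage sketch of the new signature, with data, priors, and shapes that are illustrative rather than taken from this commit:

    import numpy as np
    import pymc as pm

    X = np.linspace(0, 10, 100)[:, None]   # observed input locations
    Xu = np.linspace(0, 10, 10)[:, None]   # inducing point locations
    y = np.sin(X).ravel() + 0.1 * np.random.randn(100)

    with pm.Model() as model:
        ell = pm.Gamma("ell", alpha=2.0, beta=1.0)
        cov_func = pm.gp.cov.ExpQuad(1, ls=ell)
        gp = pm.gp.MarginalApprox(approx="VFE", cov_func=cov_func)
        sigma = pm.HalfNormal("sigma", sigma=1.0)
        # noise must now be given explicitly; the call adds a Potential named
        # "marginalapprox_loglik_lik" to the model rather than returning a variable.
        gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, noise=sigma)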
pymc/tests/test_gp.py (104 changes: 56 additions & 48 deletions)
@@ -846,63 +846,71 @@ def testLatent2(self):
 
 class TestMarginalVsMarginalApprox:
     R"""
-    Compare logp of models Marginal and MarginalApprox.
-    Should be nearly equal when inducing points are same as inputs.
+    Compare test fits of models Marginal and MarginalApprox.
     """
 
     def setup_method(self):
-        X = np.random.randn(50, 3)
-        y = np.random.randn(50)
-        Xnew = np.random.randn(60, 3)
-        pnew = np.random.randn(60)
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
-            sigma = 0.1
-            f = gp.marginal_likelihood("f", X, y, noise=sigma)
-            p = gp.conditional("p", Xnew)
-        self.logp = model.compile_logp()({"p": pnew})
-        self.X = X
-        self.Xnew = Xnew
-        self.y = y
-        self.sigma = sigma
-        self.pnew = pnew
-        self.gp = gp
+        self.sigma = 0.1
+        self.x = np.linspace(-5, 5, 30)
+        self.y = np.random.normal(0.25 * self.x, self.sigma)
+        with pm.Model() as model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0)  # far from true value
+            mean_func = pm.gp.mean.Constant(c)
+            self.gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            sigma = pm.HalfNormal("sigma", sigma=100)
+            self.gp.marginal_likelihood("lik", self.x[:, None], self.y, sigma)
+            self.map_full = pm.find_MAP(method="bfgs")  # bfgs seems to work much better than lbfgsb
+
+        self.x_new = np.linspace(-6, 6, 20)
+
+        # Include additive Gaussian noise, return diagonal of predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_var = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=True, diag=True
+            )
+
+        # Don't include additive Gaussian noise, return full predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_covar = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=False, diag=False
+            )
 
     @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testApproximations(self, approx):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            p = gp.conditional("p", self.Xnew)
-        approx_logp = model.compile_logp()({"p": self.pnew})
-        npt.assert_allclose(approx_logp, self.logp, atol=0, rtol=1e-2)
-
-    @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testPredictVar(self, approx):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, var1 = self.gp.predict(self.Xnew, diag=True)
-            mu2, var2 = gp.predict(self.Xnew, diag=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(var1, var2, atol=0, rtol=1e-3)
-
-    def testPredictCov(self):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx="DTC")
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, cov1 = self.gp.predict(self.Xnew, pred_noise=True)
-            mu2, cov2 = gp.predict(self.Xnew, pred_noise=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(cov1, cov2, atol=0, rtol=1e-3)
+    def test_fits_and_preds(self, approx):
+        """Get MAP estimate for GP approximation, compare results and predictions to what's
+        returned by an unapproximated GP. The tolerances are fairly wide, but narrow relative
+        to the initial values of the unknown parameters.
+        """
+        with pm.Model() as model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0, initval=-500.0)
+            mean_func = pm.gp.mean.Constant(c)
+            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
+            sigma = pm.HalfNormal("sigma", sigma=100, initval=50.0)
+            gp.marginal_likelihood("lik", self.x[:, None], self.x[:, None], self.y, sigma)
+            map_approx = pm.find_MAP(method="bfgs")
+
+        # Check that the MAP estimate is approximately correct
+        npt.assert_allclose(self.map_full["c"], map_approx["c"], atol=0.01, rtol=0.1)
+        npt.assert_allclose(self.map_full["sigma"], map_approx["sigma"], atol=0.01, rtol=0.1)
+
+        # Check that predict (and conditional) work: include noise, diagonal (non-full) pred var
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
+
+        # Check that predict (and conditional) work: no noise, full pred covariance
+        with model:
+            pred_mu_approx, pred_covar_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=False, diag=False
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_covar, pred_covar_approx, atol=0.0, rtol=0.1)
 
 
 class TestGPAdditive:
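
For reference, the terms assembled by _build_marginal_likelihood_loglik (constant, logdet, quadratic, trace) mirror the standard sparse-GP objectives that FITC, VFE, and DTC optimize; this is textbook background (Titsias 2009; Quinonero-Candela and Rasmussen 2005), not text from this commit. With the Nystrom approximation Q_ff = K_fu K_uu^{-1} K_uf, the VFE lower bound on the log marginal likelihood is

    \log p(y) \ge \log \mathcal{N}\big(y \mid m(X),\; Q_{ff} + \sigma^2 I\big) - \frac{1}{2\sigma^2} \operatorname{tr}\big(K_{ff} - Q_{ff}\big), \qquad Q_{ff} = K_{fu} K_{uu}^{-1} K_{uf}

FITC replaces the covariance with Q_ff + diag(K_ff - Q_ff) + sigma^2 I and drops the trace correction; DTC uses Q_ff + sigma^2 I with no correction. When Xu equals X, all three recover the exact marginal likelihood only up to jitter and numerical error, which is one reason the old exact-logp comparison was flaky; the new test instead compares MAP fits and predictions at loose tolerances.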
