MGP #65 (Open)

wants to merge 60 commits into base: master

Changes from all commits

Commits (60)
62febc7
Approximate marginalisation during prediction time using Laplace appr…
gpfins Jul 22, 2017
2f554cc
Adding docstring
gpfins Jul 22, 2017
267fdad
Small example
gpfins Jul 22, 2017
be294ee
Adding predictions for multiple points
gpfins Jul 23, 2017
56e0811
Add notebook for MGP
gpfins Jul 23, 2017
8cb9e83
Update notebook MGP
gpfins Jul 23, 2017
5eff800
Make use of the free_vars instead of collecting the variables manually
gpfins Jul 23, 2017
4a54add
Remove unnecessary functions
gpfins Jul 23, 2017
1e6d81d
Can hanle predict_f and predict_y
gpfins Jul 23, 2017
850df94
Can handle multi-output GP
gpfins Aug 6, 2017
8c4c239
Add predict_density
gpfins Aug 6, 2017
7869161
Models tests
gpfins Aug 6, 2017
2242178
Models tests
gpfins Aug 6, 2017
cf672a5
Approximate marginalisation during prediction time using Laplace appr…
gpfins Jul 22, 2017
8acfad6
Adding docstring
gpfins Jul 22, 2017
66ff92c
Small example
gpfins Jul 22, 2017
2316ac9
Adding predictions for multiple points
gpfins Jul 23, 2017
8f609cd
Add notebook for MGP
gpfins Jul 23, 2017
d884538
Update notebook MGP
gpfins Jul 23, 2017
c0a41e7
Make use of the free_vars instead of collecting the variables manually
gpfins Jul 23, 2017
c83eb73
Remove unnecessary functions
gpfins Jul 23, 2017
c0f5c00
Can hanle predict_f and predict_y
gpfins Jul 23, 2017
958eb49
Can handle multi-output GP
gpfins Aug 6, 2017
0657328
Add predict_density
gpfins Aug 6, 2017
8f79822
Models tests
gpfins Aug 6, 2017
2c940bc
Models tests
gpfins Aug 6, 2017
3b3bd35
Approximate marginalisation during prediction time using Laplace appr…
gpfins Jul 22, 2017
e9d3fd2
Adding docstring
gpfins Jul 22, 2017
908cadd
Small example
gpfins Jul 22, 2017
451f3dd
Adding predictions for multiple points
gpfins Jul 23, 2017
673fbd4
Add notebook for MGP
gpfins Jul 23, 2017
d24da18
Update notebook MGP
gpfins Jul 23, 2017
a31dfe7
Make use of the free_vars instead of collecting the variables manually
gpfins Jul 23, 2017
43daf62
Remove unnecessary functions
gpfins Jul 23, 2017
4e69a7f
Can hanle predict_f and predict_y
gpfins Jul 23, 2017
0b14297
Can handle multi-output GP
gpfins Aug 6, 2017
c85fc6e
Add predict_density
gpfins Aug 6, 2017
96c6307
Models tests
gpfins Aug 6, 2017
01a67d1
Models tests
gpfins Aug 6, 2017
9630095
Merge remote-tracking branch 'origin/mgp' into mgp
gpfins Aug 6, 2017
cd6e969
Add objective to __init__
gpfins Aug 7, 2017
34ca364
Delete testmgp.py
gpfins Aug 7, 2017
19a6937
Update .travis.yml
gpfins Aug 7, 2017
94e1028
Update test_models.py
gpfins Aug 8, 2017
f2ac9e8
Merge branch 'master' into mgp
javdrher Aug 8, 2017
982a957
Refactoring DataScaler to parameterized
javdrher Aug 9, 2017
2cc1855
Bugfix in modelwrapper, make sure compile flag is set correctly
javdrher Aug 9, 2017
948965d
Adding tests for the wrapping super class
javdrher Aug 9, 2017
4472079
Forgot adding the file
javdrher Aug 9, 2017
b4c3615
Bugfix
gpfins Aug 9, 2017
9b95f3c
Merge branch 'model_wrapper_class' into mgp
javdrher Aug 9, 2017
f8bc92f
Moving rowwise_gradients to tf_wraps
javdrher Aug 9, 2017
79c7761
Added test directives to verify variance of MGP
javdrher Aug 10, 2017
8e7d9b3
Enable wrapping of modelwrappers
javdrher Aug 10, 2017
e423653
Documentation
javdrher Aug 10, 2017
2302673
Merge branch 'model_wrapper_class' into mgp
javdrher Aug 10, 2017
2bf2c44
Solving asserts, and allow modelwrappers as input in acquisition
javdrher Aug 13, 2017
715b99f
Merge branch 'model_wrapper_class' into mgp
javdrher Aug 13, 2017
0301ffa
Improving setattr in modelwrapper to cope with some inconsistensies d…
javdrher Aug 13, 2017
c2764eb
Merge branch 'model_wrapper_class' into mgp
javdrher Aug 13, 2017
2 changes: 1 addition & 1 deletion .travis.yml
@@ -8,7 +8,7 @@ python:
cache: pip
install:
- pip install -U pip wheel
- pip install tensorflow==1.0.1
- pip install tensorflow==1.3.0rc0
- pip install --process-dependency-links .
- pip install .[test]
- pip install codecov
2 changes: 2 additions & 0 deletions GPflowOpt/__init__.py
@@ -20,4 +20,6 @@
from . import transforms
from . import scaling
from . import objective
from . import models
from . import pareto
from . import models
Comment (Contributor): import models twice?

6 changes: 5 additions & 1 deletion GPflowOpt/acquisition/acquisition.py
@@ -14,8 +14,10 @@

from ..scaling import DataScaler
from ..domain import UnitCube
from ..models import ModelWrapper

from GPflow.param import Parameterized, AutoFlow, ParamList
from GPflow.model import Model
from GPflow import settings

import numpy as np
@@ -48,7 +50,9 @@ def __init__(self, models=[], optimize_restarts=5):
:param optimize_restarts: number of optimization restarts to use when training the models
"""
super(Acquisition, self).__init__()
self._models = ParamList([DataScaler(m) for m in np.atleast_1d(models).tolist()])
models = np.atleast_1d(models)
assert all(isinstance(model, (Model, ModelWrapper)) for model in models)
self._models = ParamList([DataScaler(m) for m in models])
self._default_params = list(map(lambda m: m.get_free_state(), self._models))

assert (optimize_restarts >= 0)
1 change: 0 additions & 1 deletion GPflowOpt/acquisition/ei.py
@@ -57,7 +57,6 @@ def __init__(self, model):
:param model: GPflow model (single output) representing our belief of the objective
"""
super(ExpectedImprovement, self).__init__(model)
assert (isinstance(model, Model))
self.fmin = DataHolder(np.zeros(1))
self.setup()

167 changes: 167 additions & 0 deletions GPflowOpt/models.py
@@ -0,0 +1,167 @@
# Copyright 2017 Joachim van der Herten, Nicolas Knudde
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tf_wraps import rowwise_gradients

from GPflow.param import Parameterized, AutoFlow
from GPflow.model import Model, GPModel
from GPflow.likelihoods import Gaussian
from GPflow import settings

import tensorflow as tf

float_type = settings.dtypes.float_type


class ModelWrapper(Parameterized):
"""
Class for fast implementation of a wrapper for models defined in GPflow. Once wrapped, all lookups for attributes
which are not found in the wrapper class are automatically forwarded to the wrapped model.

To influence the I/O of methods on the wrapped class, simply implement the method in the wrapper and call the
appropriate methods on the wrapped class. Specific logic is included to make sure that if AutoFlow methods are
influenced following this pattern, the original AF storage (if existing) is unaffected and a new storage is added
to the subclass.
"""
def __init__(self, model):
"""
:param model: model to be wrapped
"""
super(ModelWrapper, self).__init__()

assert isinstance(model, (Model, ModelWrapper))
#: Wrapped model
self.wrapped = model

def __getattr__(self, item):
"""
If an attribute is not found in this class, it is searched in the wrapped model
"""
# Exception for AF storages, if a method with the same name exists in this class, do not find the cache
# in the wrapped model.
if item.endswith('_AF_storage'):
method = item[1:].rstrip('_AF_storage')
if method in dir(self):
raise AttributeError("{0} has no attribute {1}".format(self.__class__.__name__, item))
return getattr(self.wrapped, item)

def __setattr__(self, key, value):
"""
1) If setting :attr:`wrapped` attribute, point parent to this object (the datascaler).
2) If setting the recompilation attribute, always do this on the wrapped class.
"""
if key is 'wrapped':
object.__setattr__(self, key, value)
value.__setattr__('_parent', self)
return

try:
# If attribute is in this object, set it. Test by using getattribute instead of hasattr to avoid lookup in
# wrapped object.
self.__getattribute__(key)
super(ModelWrapper, self).__setattr__(key, value)
except AttributeError:
# Attribute is not in wrapper.
# In case no wrapped object is set yet (e.g. constructor), set in wrapper.
if 'wrapped' not in self.__dict__:
super(ModelWrapper, self).__setattr__(key, value)
return

if hasattr(self, key):
# Now use hasattr, we know getattribute already failed so if it returns true, it must be in the wrapped
# object. Hasattr is called on self instead of self.wrapped to account for the different handling of
# AF storages.
# Prefer setting the attribute in the wrapped object if exists.
setattr(self.wrapped, key, value)
else:
# If not, set in wrapper nonetheless.
super(ModelWrapper, self).__setattr__(key, value)

def __eq__(self, other):
return self.wrapped == other

def __str__(self, prepend=''):
return self.wrapped.__str__(prepend)


class MGP(ModelWrapper):
"""
Marginalisation of the hyperparameters during evaluation time using a Laplace Approximation
Key reference:

::

@article{Garnett:2013,
title={Active learning of linear embeddings for Gaussian processes},
author={Garnett, Roman and Osborne, Michael A and Hennig, Philipp},
journal={arXiv preprint arXiv:1310.6740},
year={2013}
}
"""

def __init__(self, model):
assert isinstance(model, GPModel), "Object has to be a GP model"
assert isinstance(model.likelihood, Gaussian), "Likelihood has to be Gaussian"
super(MGP, self).__init__(model)

def build_predict(self, fmean, fvar, theta):
h = tf.hessians(self.build_likelihood() + self.build_prior(), theta)[0]
L = tf.cholesky(-h)

N = tf.shape(fmean)[0]
D = tf.shape(fmean)[1]

fmeanf = tf.reshape(fmean, [N * D, 1]) # N*D x 1
Comment (Contributor): N x D x 1 :)

Reply (Contributor): No, has to be N*D x 1 so I can use rowwise_gradients, then I reshape later.

fvarf = tf.reshape(fvar, [N * D, 1]) # N*D x 1

Dfmean = rowwise_gradients(fmeanf, theta) # N*D x k
Dfvar = rowwise_gradients(fvarf, theta) # N*D x k

tmp1 = tf.transpose(tf.matrix_triangular_solve(L, tf.transpose(Dfmean))) # N*D x k
tmp2 = tf.transpose(tf.matrix_triangular_solve(L, tf.transpose(Dfvar))) # N*D x k
return fmean, 4 / 3 * fvar + tf.reshape(tf.reduce_sum(tf.square(tmp1), axis=1), [N, D]) \
+ 1 / 3 / (fvar + 1E-3) * tf.reshape(tf.reduce_sum(tf.square(tmp2), axis=1), [N, D])
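
For reference, a transcription of what build_predict above computes, as I read the code (following Garnett et al., 2013); Sigma is the Laplace covariance of the hyperparameters and the 1e-3 term is the jitter added in the code:

    \tilde{\mu}(x) = \mu(x;\hat\theta)
    \tilde{\sigma}^2(x) = \tfrac{4}{3}\,\sigma^2(x;\hat\theta)
        + \nabla_\theta\mu^\top \Sigma\, \nabla_\theta\mu
        + \frac{1}{3\,(\sigma^2(x;\hat\theta)+10^{-3})}\, \nabla_\theta\sigma^2{}^\top \Sigma\, \nabla_\theta\sigma^2,
    \qquad \Sigma = \left(-\nabla\nabla_\theta\left[\log p(\mathbf{y}\mid\theta)+\log p(\theta)\right]\right)^{-1}

where L in the code is the Cholesky factor of \Sigma^{-1}, so the two reduce_sum(tf.square(...)) terms are exactly these quadratic forms.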

@AutoFlow((float_type, [None, None]))
def predict_f(self, Xnew):
"""
Compute the mean and variance of the latent function(s) at the points
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update doc string? marginalised around ...

Xnew.
"""
theta = self._predict_f_AF_storage['free_vars']
Comment (Contributor): ugh

Reply (Contributor): No other way, have to wait for GPflow issue.

fmean, fvar = self.wrapped.build_predict(Xnew)
return self.build_predict(fmean, fvar, theta)

@AutoFlow((float_type, [None, None]))
def predict_y(self, Xnew):
"""
Compute the mean and variance of held-out data at the points Xnew
"""
theta = self._predict_y_AF_storage['free_vars']
pred_f_mean, pred_f_var = self.wrapped.build_predict(Xnew)
fmean, fvar = self.wrapped.likelihood.predict_mean_and_var(pred_f_mean, pred_f_var)
return self.build_predict(fmean, fvar, theta)

@AutoFlow((float_type, [None, None]), (float_type, [None, None]))
def predict_density(self, Xnew, Ynew):
"""
Compute the (log) density of the data Ynew at the points Xnew

Note that this computes the log density of the data individually,
ignoring correlations between them. The result is a matrix the same
shape as Ynew containing the log densities.
"""
theta = self._predict_density_AF_storage['free_vars']
pred_f_mean, pred_f_var = self.wrapped.build_predict(Xnew)
pred_f_mean, pred_f_var = self.build_predict(pred_f_mean, pred_f_var, theta)
return self.likelihood.predict_density(pred_f_mean, pred_f_var, Ynew)
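
A minimal usage sketch of the new wrapper (not part of the diff; assumes the GPflow 0.x API targeted by this branch, with toy data invented for illustration):

    import numpy as np
    import GPflow
    from GPflowOpt.models import MGP

    # Toy 1-D regression data (illustration only)
    X = np.random.rand(20, 1)
    Y = np.sin(10 * X) + 0.05 * np.random.randn(20, 1)

    model = GPflow.gpr.GPR(X, Y, GPflow.kernels.RBF(1))
    model.optimize()    # MAP/ML point estimate of the hyperparameters

    mgp = MGP(model)    # wrap: predictions now marginalise over the hyperparameters
    mean, var = mgp.predict_f(np.linspace(0, 1, 50)[:, None])

    # Attribute lookups not defined on the wrapper are forwarded to the wrapped model
    print(mgp.kern)     # resolves to model.kern via ModelWrapper.__getattr__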
62 changes: 23 additions & 39 deletions GPflowOpt/scaling.py
@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from GPflow.param import DataHolder, AutoFlow, Parameterized
from GPflow.model import Model, GPModel
from GPflow.param import DataHolder, AutoFlow
from GPflow.model import GPModel
from GPflow import settings
import numpy as np
from .transforms import LinearTransform, DataTransform
from .domain import UnitCube
from .models import ModelWrapper

float_type = settings.dtypes.float_type


class DataScaler(GPModel):
class DataScaler(ModelWrapper):
"""
Model-wrapping class, primarily intended to assure the data in GPflow models is scaled. One DataScaler wraps one
GPflow model, and can scale the input as well as the output data. By default, if any kind of object attribute
@@ -59,13 +60,8 @@ def __init__(self, model, domain=None, normalize_Y=False):
:param normalize_Y: (default: False) enable automatic scaling of output values to zero mean and unit
variance.
"""
# model sanity checks
assert (model is not None)
assert (isinstance(model, GPModel))
self._parent = None

# Wrap model
self.wrapped = model
# model sanity checks, slightly stronger conditions than the wrapper
super(DataScaler, self).__init__(model)

# Initial configuration of the datascaler
n_inputs = model.X.shape[1]
@@ -74,34 +70,8 @@ self._normalize_Y = normalize_Y
self._normalize_Y = normalize_Y
self._output_transform = LinearTransform(np.ones(n_outputs), np.zeros(n_outputs))

# The assignments in the constructor of GPModel take care of initial re-scaling of model data.
super(DataScaler, self).__init__(model.X.value, model.Y.value, None, None, 1, name=model.name+"_datascaler")
del self.kern
del self.mean_function
del self.likelihood

def __getattr__(self, item):
"""
If an attribute is not found in this class, it is searched in the wrapped model
"""
return self.wrapped.__getattribute__(item)

def __setattr__(self, key, value):
"""
If setting :attr:`wrapped` attribute, point parent to this object (the datascaler)
"""
if key is 'wrapped':
object.__setattr__(self, key, value)
value.__setattr__('_parent', self)
return

super(DataScaler, self).__setattr__(key, value)

def __eq__(self, other):
return self.wrapped == other

def __str__(self, prepend=''):
return self.wrapped.__str__(prepend)
self.X = model.X.value
self.Y = model.Y.value

@property
def input_transform(self):
@@ -216,6 +186,20 @@ def build_predict(self, Xnew, full_cov=False):
f, var = self.wrapped.build_predict(self.input_transform.build_forward(Xnew), full_cov=full_cov)
return self.output_transform.build_backward(f), self.output_transform.build_backward_variance(var)

@AutoFlow((float_type, [None, None]))
def predict_f(self, Xnew):
"""
Compute the mean and variance of held-out data at the points Xnew
"""
return self.build_predict(Xnew)

@AutoFlow((float_type, [None, None]))
def predict_f_full_cov(self, Xnew):
"""
Compute the mean and variance of held-out data at the points Xnew
"""
return self.build_predict(Xnew, full_cov=True)

@AutoFlow((float_type, [None, None]))
def predict_y(self, Xnew):
"""
@@ -230,6 +214,6 @@ def predict_density(self, Xnew, Ynew):
"""
Compute the (log) density of the data Ynew at the points Xnew
"""
mu, var = self.build_predict(Xnew)
mu, var = self.wrapped.build_predict(self.input_transform.build_forward(Xnew))
Ys = self.output_transform.build_forward(Ynew)
return self.likelihood.predict_density(mu, var, Ys)
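
For completeness, a hedged sketch of the refactored DataScaler in use (Xcand and the two-parameter domain are invented for illustration; assumes the GPflowOpt domain API):

    import numpy as np
    from GPflowOpt.domain import ContinuousParameter
    from GPflowOpt.scaling import DataScaler

    domain = ContinuousParameter('x1', 0, 10) + ContinuousParameter('x2', -5, 5)
    scaled = DataScaler(model, domain, normalize_Y=True)  # model: a GPModel trained on data in this domain

    Xcand = np.random.rand(10, 2) * [10, 10] + [0, -5]    # candidate points in the original domain
    mean, var = scaled.predict_f(Xcand)                   # inputs mapped to the unit cube, outputs back-transformed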
42 changes: 42 additions & 0 deletions GPflowOpt/tf_wraps.py
@@ -0,0 +1,42 @@
# Copyright 2017 Joachim van der Herten, Nicolas Knudde
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from GPflow import settings

float_type = settings.dtypes.float_type


def rowwise_gradients(Y, X):
"""
For a 2D Tensor Y, compute the derivative of each row w.r.t. a 1-D tensor X.

This is done with while_loop, because of a known incompatibility between map_fn and gradients.
"""
num_rows = tf.shape(Y)[0]
num_feat = tf.shape(X)[0]

def body(old_grads, row):
g = tf.expand_dims(tf.gradients(Y[row], X)[0], axis=0)
new_grads = tf.concat([old_grads, g], axis=0)
return new_grads, row + 1

def cond(_, row):
return tf.less(row, num_rows)

shape_invariants = [tf.TensorShape([None, None]), tf.TensorShape([])]
grads, _ = tf.while_loop(cond, body, [tf.zeros([0, num_feat], float_type), tf.constant(0)],
shape_invariants=shape_invariants)

return grads
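
A quick sanity check of the helper above (not part of the diff; TF 1.x graph mode, names invented). For y = X w, the gradient of each y_i w.r.t. w is the i-th row of X:

    import numpy as np
    import tensorflow as tf
    from GPflowOpt.tf_wraps import rowwise_gradients

    Xv = np.random.rand(4, 3)                    # float64, matching GPflow's default float_type
    x = tf.constant(Xv)
    w = tf.Variable(np.random.rand(3))           # k = 3 free variables
    y = tf.matmul(x, tf.expand_dims(w, 1))       # 4 x 1

    grads = rowwise_gradients(y, w)              # 4 x 3, row i should equal Xv[i]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(np.allclose(sess.run(grads), Xv))  # True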
8 changes: 8 additions & 0 deletions doc/source/interfaces.rst
@@ -36,3 +36,11 @@ Transform
:special-members:
.. autoclass:: GPflowOpt.transforms.DataTransform
:special-members:

ModelWrapper
------------
.. automodule:: GPflowOpt.models
:special-members:
.. autoclass:: GPflowOpt.models.ModelWrapper
:members:
:special-members: