-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding support for Neural Network Embedding #895
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
xgboost==0.6a2 | ||
scikit-mdr==0.4.4 | ||
skrebate==0.3.4 | ||
tensorflow>=1.12.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from sklearn.datasets import make_classification, make_regression | ||
from tpot.builtins import EmbeddingEstimator | ||
from sklearn.neural_network import MLPClassifier, MLPRegressor | ||
|
||
|
||
def test_EmbeddingClassifier_1(): | ||
"""Assert that Embedding for classification works as expected.""" | ||
X, y = make_classification(random_state=1) | ||
cs = EmbeddingEstimator(MLPClassifier(random_state=1, tol=0.9)) | ||
X_transformed = cs.fit_transform(X, y) | ||
|
||
# 20 features + 100 embedding size | ||
assert X_transformed.shape[1] == 120 | ||
|
||
|
||
def test_EmbeddingClassifier_2(): | ||
"""Assert that correct embedding layer is selected (classifier).""" | ||
X, y = make_classification(random_state=1) | ||
cs = EmbeddingEstimator( | ||
MLPClassifier(hidden_layer_sizes=[20, 10], random_state=1, tol=0.9) | ||
) | ||
cs_2 = EmbeddingEstimator( | ||
MLPClassifier(hidden_layer_sizes=[20, 10], random_state=1, tol=0.9), | ||
embedding_layer=1, | ||
) | ||
X_transformed = cs.fit_transform(X, y) | ||
X_transformed_2 = cs_2.fit_transform(X, y) | ||
|
||
assert X_transformed.shape[1] == 30 # 20 features + 20 embedding size | ||
assert X_transformed_2.shape[1] == 40 # 20 features + 20 embedding size | ||
|
||
|
||
def test_EmbeddingRegressor_1(): | ||
"""Assert that Embedding for regressor works as expected.""" | ||
X, y = make_regression(n_features=20, random_state=1) | ||
cs = EmbeddingEstimator(MLPRegressor(random_state=1, tol=1000)) | ||
X_transformed = cs.fit_transform(X, y) | ||
|
||
# 20 features + 100 embedding size | ||
assert X_transformed.shape[1] == 120 | ||
|
||
|
||
def test_EmbeddingRegressor_2(): | ||
"""Assert that correct embedding layer is selected (regressor).""" | ||
X, y = make_regression(n_features=20, random_state=1) | ||
cs = EmbeddingEstimator( | ||
MLPRegressor(hidden_layer_sizes=[20, 10], random_state=1, tol=1000) | ||
) | ||
cs_2 = EmbeddingEstimator( | ||
MLPRegressor(hidden_layer_sizes=[20, 10], random_state=1, tol=1000), | ||
embedding_layer=1, | ||
) | ||
X_transformed = cs.fit_transform(X, y) | ||
X_transformed_2 = cs_2.fit_transform(X, y) | ||
|
||
assert X_transformed.shape[1] == 30 # 20 features + 20 embedding size | ||
assert X_transformed_2.shape[1] == 40 # 20 features + 20 embedding size | ||
|
||
|
||
def test_EmbeddingKeras(): | ||
"""Check that this works also for keras models""" | ||
try: | ||
import tensorflow as tf | ||
except ImportError: | ||
tf = None | ||
if tf is None: | ||
return | ||
from tensorflow.keras import backend as K | ||
import tensorflow.keras as keras | ||
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier | ||
from tensorflow.keras.models import Sequential | ||
from tensorflow.keras.layers import Dense, Activation | ||
|
||
def make_model(input_shape): | ||
model = Sequential() | ||
model.add(Dense(20, activation="relu", input_dim=input_shape)) | ||
model.add(Dense(15, activation="relu")) | ||
model.add(Dense(2, activation="softmax")) | ||
model.compile( | ||
optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"] | ||
) | ||
return model | ||
|
||
X, y = make_classification(random_state=1) | ||
cs = EmbeddingEstimator(KerasClassifier(make_model), backend=K) | ||
cs_2 = EmbeddingEstimator( | ||
KerasClassifier(make_model), embedding_layer=-3, backend=K | ||
) | ||
X_transformed = cs.fit_transform(X, y, verbose=0) | ||
X_transformed_2 = cs_2.fit_transform(X, y, verbose=0) | ||
|
||
assert X_transformed.shape[1] == 35 # 20 features + 15 embedding size | ||
assert X_transformed_2.shape[1] == 40 # 20 features + 20 embedding size |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
"""This file is part of the TPOT library. | ||
|
||
TPOT was primarily developed at the University of Pennsylvania by: | ||
- Randal S. Olson ([email protected]) | ||
- Weixuan Fu ([email protected]) | ||
- Daniel Angell ([email protected]) | ||
- and many more generous open source contributors | ||
|
||
TPOT is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU Lesser General Public License as | ||
published by the Free Software Foundation, either version 3 of | ||
the License, or (at your option) any later version. | ||
|
||
TPOT is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU Lesser General Public License for more details. | ||
|
||
You should have received a copy of the GNU Lesser General Public | ||
License along with TPOT. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
""" | ||
|
||
import numpy as np | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.utils import check_array | ||
from sklearn.neural_network import MLPClassifier, MLPRegressor | ||
|
||
|
||
class EmbeddingEstimator(TransformerMixin, BaseEstimator): | ||
"""Meta-transformer for creating neural network embeddings as features. | ||
""" | ||
|
||
def __init__(self, estimator, embedding_layer=None, backend=None): | ||
"""Create a StackingEstimator object. | ||
|
||
Parameters | ||
---------- | ||
estimator: neural network model; either from sklearn or Keras-like. | ||
The estimator to generate embeddings. | ||
embedding_layer: the particular layer used as the embedding. | ||
By default we use the second last layer. Layers are counted with | ||
input layer being `0th` layer; negative indices are allowed. | ||
backend: (optional), the backend we use to query the neural network. | ||
Not required if using scikit-learn interface. | ||
Currently only supports keras-like interface (incl. tensorflow) | ||
""" | ||
second_last_layer = -2 | ||
self.estimator = estimator | ||
self.embedding_layer = ( | ||
second_last_layer if embedding_layer is None else embedding_layer | ||
) | ||
self.backend = backend | ||
|
||
def fit(self, X, y=None, **fit_params): | ||
"""Fit the StackingEstimator meta-transformer. | ||
|
||
Parameters | ||
---------- | ||
X: array-like of shape (n_samples, n_features) | ||
The training input samples. | ||
y: array-like, shape (n_samples,) | ||
The target values (integers that correspond to classes in classification, real numbers in regression). | ||
fit_params: | ||
Other estimator-specific parameters. | ||
|
||
Returns | ||
------- | ||
self: object | ||
Returns a copy of the estimator | ||
""" | ||
if not issubclass(self.estimator.__class__, MLPClassifier) and not issubclass( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Q: how would we support pure tensorflow or pytorch? I've used Keras here to try to use the Keras scikit-learn wrapper There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for this PR. Keras scikit-learn wrapper is good so far. |
||
self.estimator.__class__, MLPRegressor | ||
): | ||
input_shape = X.shape[1] | ||
self.estimator.sk_params["input_shape"] = input_shape | ||
self.estimator.check_params(self.estimator.sk_params) | ||
self.estimator.fit(X, y, **fit_params) | ||
return self | ||
|
||
def transform(self, X): | ||
"""Transform data by adding embedding as features. | ||
|
||
Parameters | ||
---------- | ||
X: numpy ndarray, {n_samples, n_components} | ||
New data, where n_samples is the number of samples and n_components is the number of components. | ||
|
||
Returns | ||
------- | ||
X_transformed: array-like, shape (n_samples, n_features + embedding) where embedding is the size of the embedding layer | ||
The transformed feature set. | ||
""" | ||
X = check_array(X) | ||
X_transformed = np.copy(X) | ||
# add class probabilities as a synthetic feature | ||
if issubclass(self.estimator.__class__, MLPClassifier) or issubclass( | ||
self.estimator.__class__, MLPRegressor | ||
): | ||
X_transformed = np.hstack( | ||
(self._embedding_mlp(self.estimator, X), X_transformed) | ||
) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Chose not to do checking here otherwise you'll need to do an import at the top; which adds whatever flavour of the month neural network library to be a hard dependency to this. I don't have better ideas at this point but happy for any suggestions |
||
X_transformed = np.hstack( | ||
(self._embedding_keras(self.estimator, X), X_transformed) | ||
) | ||
|
||
return X_transformed | ||
|
||
def _embedding_mlp(self, estimator, X): | ||
# see also BaseMultilayerPerceptron._predict from | ||
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neural_network/multilayer_perceptron.py | ||
X = check_array(X, accept_sparse=["csr", "csc", "coo"]) | ||
|
||
# Make sure self.hidden_layer_sizes is a list | ||
hidden_layer_sizes = estimator.hidden_layer_sizes | ||
if not hasattr(hidden_layer_sizes, "__iter__"): | ||
hidden_layer_sizes = [hidden_layer_sizes] | ||
hidden_layer_sizes = list(hidden_layer_sizes) | ||
|
||
layer_units = [X.shape[1]] + hidden_layer_sizes + [estimator.n_outputs_] | ||
|
||
# Initialize layers | ||
activations = [X] | ||
|
||
for i in range(estimator.n_layers_ - 1): | ||
activations.append(np.empty((X.shape[0], layer_units[i + 1]))) | ||
# forward propagate | ||
estimator._forward_pass(activations) | ||
y_embedding = activations[self.embedding_layer] | ||
|
||
return y_embedding | ||
|
||
def _embedding_keras(self, estimator, X): | ||
X = check_array(X, accept_sparse=["csr", "csc", "coo"]) | ||
get_embedding = self.backend.function( | ||
[estimator.model.layers[0].input], | ||
[estimator.model.layers[self.embedding_layer].output], | ||
) | ||
return get_embedding([X])[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you guys prefer raising a SkipTest here or leaving it like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SkipTest should be OK since TF should be a optional dependency.