[ENH] Migrate ResNetClassifier from sktime-dl to sktime (#3881)
Part of #3351. See also #3365.

Migrates `ResNetClassifier` from `sktime-dl` to `sktime`
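
A minimal usage sketch of the migrated estimator (assuming `tensorflow` is installed; it mirrors the docstring example, with `predict` added):

```python
from sktime.classification.deep_learning.resnet import ResNetClassifier
from sktime.datasets import load_unit_test

# a small dataset and few epochs keep the example quick to run
X_train, y_train = load_unit_test(split="train")
X_test, _ = load_unit_test(split="test")

clf = ResNetClassifier(n_epochs=20, batch_size=4)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
```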
nilesh05apr authored Dec 11, 2022
1 parent aea3e7a commit c10cedb
Showing 4 changed files with 381 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
@@ -156,6 +156,15 @@
"design"
]
},
{
"login": "nilesh05apr",
"name": "Nilesh Kumar",
"avatar_url": "https://avatars.githubusercontent.com/u/65773314?v=4",
"profile": "https://github.com/nilesh05apr"
"contributions": [
"code",
]
},
{
"login": "MatthewMiddlehurst",
"name": "Matthew Middlehurst",
215 changes: 215 additions & 0 deletions sktime/classification/deep_learning/resnet.py
@@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
"""Residual Network (ResNet) for classification."""

__author__ = ["James-Large", "AurumnPegasus", "nilesh05apr"]
__all__ = ["ResNetClassifier"]

from sklearn.utils import check_random_state

from sktime.classification.deep_learning.base import BaseDeepClassifier
from sktime.networks.resnet import ResNetNetwork
from sktime.utils.validation._dependencies import _check_dl_dependencies

_check_dl_dependencies(severity="warning")
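# warn at import time if tensorflow is missing; the constructor re-checks
# with severity="error" and raises instead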


class ResNetClassifier(BaseDeepClassifier):
"""
Residual Neural Network as described in [1].
Parameters
----------
n_epochs : int, default = 1500
the number of epochs to train the model
batch_size : int, default = 16
the number of samples per gradient update.
random_state : int or None, default=None
Seed for random number generation.
verbose : boolean, default = False
whether to output extra information
loss : string, default="mean_squared_error"
fit parameter for the keras model
optimizer : keras.optimizer, default=keras.optimizers.Adam(),
metrics : list of strings, default=["accuracy"],
activation : string or a tf callable, default="sigmoid"
Activation function used in the output linear layer.
List of available activation functions:
https://keras.io/api/layers/activations/
use_bias : boolean, default = True
whether the layer uses a bias vector.
optimizer : keras.optimizers object, default = Adam(lr=0.01)
specify the optimizer and the learning rate to be used.
Notes
-----
Adapted from the implementation from source code
https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/resnet.py
References
----------
.. [1] Wang et. al, Time series classification from
scratch with deep neural networks: A strong baseline,
International joint conference on neural networks (IJCNN), 2017.
Examples
--------
>>> from sktime.classification.deep_learning.resnet import ResNetClassifier
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train")
>>> clf = ResNetClassifier(n_epochs=20, bacth_size=4) # doctest: +SKIP
>>> clf.fit(X_train, Y_train) # doctest: +SKIP
ResNetClassifier(...)
"""

_tags = {"python_dependencies": ["tensorflow"]}

def __init__(
self,
n_epochs=1500,
callbacks=None,
verbose=False,
loss="categorical_crossentropy",
metrics=None,
batch_size=16,
random_state=None,
activation="sigmoid",
use_bias=True,
optimizer=None,
):
_check_dl_dependencies(severity="error")
super(ResNetClassifier, self).__init__()
self.n_epochs = n_epochs
self.callbacks = callbacks
self.verbose = verbose
self.loss = loss
self.metrics = metrics
self.batch_size = batch_size
self.random_state = random_state
self.activation = activation
self.use_bias = use_bias
self.optimizer = optimizer
self.history = None
self._network = ResNetNetwork(random_state=random_state)

def build_model(self, input_shape, n_classes, **kwargs):
"""Construct a compiled, un-trained, keras model that is ready for training.
In sktime, time series are stored in numpy arrays of shape (d,m), where d
is the number of dimensions, m is the series length. Keras/tensorflow assume
data is in shape (m,d). This method also assumes (m,d). Transpose should
happen in fit.
Parameters
----------
input_shape : tuple
The shape of the data fed into the input layer, should be (m,d)
n_classes: int
The number of classes, which becomes the size of the output layer
Returns
-------
output : a compiled Keras Model
"""
import tensorflow as tf
from tensorflow import keras

tf.random.set_seed(self.random_state)

self.optimizer_ = (
keras.optimizers.Adam(learning_rate=0.01)
if self.optimizer is None
else self.optimizer
)

if self.metrics is None:
metrics = ["accuracy"]
else:
metrics = self.metrics

input_layer, output_layer = self._network.build_network(input_shape, **kwargs)

output_layer = keras.layers.Dense(
units=n_classes, activation=self.activation, use_bias=self.use_bias
)(output_layer)

model = keras.models.Model(inputs=input_layer, outputs=output_layer)
model.compile(
loss=self.loss,
optimizer=self.optimizer_,
metrics=metrics,
)

return model

def _fit(self, X, y):
"""Fit the classifier on the training set (X, y).
Parameters
----------
X : np.ndarray of shape = (n_instances (n), n_dimensions (d), series_length (m))
The training input samples.
y : np.ndarray of shape n
The training data class labels.
Returns
-------
self : object
"""
# use the supplied callbacks, falling back to an empty list
self._callbacks = self.callbacks if self.callbacks is not None else []

y_onehot = self.convert_y_to_keras(y)
# Transpose to conform to Keras input style.
X = X.transpose(0, 2, 1)
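# e.g. a panel of 10 univariate series of length 24: (10, 1, 24) -> (10, 24, 1)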

check_random_state(self.random_state)
self.input_shape = X.shape[1:]
self.model_ = self.build_model(self.input_shape, self.n_classes_)
if self.verbose:
self.model_.summary()
self.history = self.model_.fit(
X,
y_onehot,
batch_size=self.batch_size,
epochs=self.n_epochs,
verbose=self.verbose,
callbacks=self._callbacks,
)
return self

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
For classifiers, a "default" set of parameters should be provided for
general testing, and a "results_comparison" set for comparing against
previously recorded results if the general set does not produce suitable
probabilities to compare against.
Returns
-------
params : dict or list of dict, default={}
Parameters to create testing instances of the class.
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`.
"""
param1 = {
"n_epochs": 10,
"batch_size": 4,
"use_bias": False,
}

param2 = {
"n_epochs": 12,
"batch_size": 6,
"use_bias": True,
}

return [param1, param2]
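
For reference, a minimal sketch (assuming the deep-learning soft dependencies are installed) of how these parameter sets produce test instances:

```python
from sktime.classification.deep_learning.resnet import ResNetClassifier

# each dict returned by get_test_params constructs a valid, cheap-to-fit instance
for params in ResNetClassifier.get_test_params():
    clf = ResNetClassifier(**params)
```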
154 changes: 154 additions & 0 deletions sktime/networks/resnet.py
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""Residual Network (ResNet) (minus the final output layer)."""

__author__ = ["James Large", "Withington", "nilesh05apr"]

from sktime.networks.base import BaseDeepNetwork
from sktime.utils.validation._dependencies import _check_dl_dependencies


class ResNetNetwork(BaseDeepNetwork):
"""
Establish the network structure for a ResNet.
Adapted from the implementations used in [1]
Parameters
----------
random_state : int, optional (default = 0)
The random seed to use random activities.
Notes
-----
Adpated from the implementation source code
https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/resnet.py
References
----------
.. [1] H. Fawaz, G. B. Lanckriet, F. Petitjean, and L. Idoumghar,
Network originally defined in:
@inproceedings{wang2017time, title={Time series classification from
scratch with deep neural networks: A strong baseline}, author={Wang,
Zhiguang and Yan, Weizhong and Oates, Tim}, booktitle={2017
International joint conference on neural networks (IJCNN)}, pages={
1578--1585}, year={2017}, organization={IEEE} }
"""

_tags = {"python_dependencies": ["tensorflow"]}

def __init__(self, random_state=0):
_check_dl_dependencies(severity="error")
super(ResNetNetwork, self).__init__()
self.random_state = random_state

def build_network(self, input_shape, **kwargs):
"""
Construct a network and return its input and output layers.
Arguments
---------
input_shape : tuple of shape = (series_length (m), n_dimensions (d))
The shape of the data fed into the input layer.
Returns
-------
input_layer : keras.layers.Input
The input layer of the network.
output_layer : keras.layers.Layer
The output layer of the network.
"""
from tensorflow import keras

n_feature_maps = 64
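# width of the first residual block; the second and third blocks double it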

input_layer = keras.layers.Input(input_shape)

# 1st residual block

conv_x = keras.layers.Conv1D(
filters=n_feature_maps, kernel_size=8, padding="same"
)(input_layer)
conv_x = keras.layers.BatchNormalization()(conv_x)
conv_x = keras.layers.Activation("relu")(conv_x)

conv_y = keras.layers.Conv1D(
filters=n_feature_maps, kernel_size=5, padding="same"
)(conv_x)
conv_y = keras.layers.BatchNormalization()(conv_y)
conv_y = keras.layers.Activation("relu")(conv_y)

conv_z = keras.layers.Conv1D(
filters=n_feature_maps, kernel_size=3, padding="same"
)(conv_y)
conv_z = keras.layers.BatchNormalization()(conv_z)

# expand channels for the sum
shortcut_y = keras.layers.Conv1D(
filters=n_feature_maps, kernel_size=1, padding="same"
)(input_layer)
shortcut_y = keras.layers.BatchNormalization()(shortcut_y)

output_block_1 = keras.layers.add([shortcut_y, conv_z])
output_block_1 = keras.layers.Activation("relu")(output_block_1)

# 2nd residual block

conv_x = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=8, padding="same"
)(output_block_1)
conv_x = keras.layers.BatchNormalization()(conv_x)
conv_x = keras.layers.Activation("relu")(conv_x)

conv_y = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=5, padding="same"
)(conv_x)
conv_y = keras.layers.BatchNormalization()(conv_y)
conv_y = keras.layers.Activation("relu")(conv_y)

conv_z = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=3, padding="same"
)(conv_y)
conv_z = keras.layers.BatchNormalization()(conv_z)

# expand channels for the sum
shortcut_y = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=1, padding="same"
)(output_block_1)
shortcut_y = keras.layers.BatchNormalization()(shortcut_y)

output_block_2 = keras.layers.add([shortcut_y, conv_z])
output_block_2 = keras.layers.Activation("relu")(output_block_2)

# 3rd residual block

conv_x = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=8, padding="same"
)(output_block_2)
conv_x = keras.layers.BatchNormalization()(conv_x)
conv_x = keras.layers.Activation("relu")(conv_x)

conv_y = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=5, padding="same"
)(conv_x)
conv_y = keras.layers.BatchNormalization()(conv_y)
conv_y = keras.layers.Activation("relu")(conv_y)

conv_z = keras.layers.Conv1D(
filters=n_feature_maps * 2, kernel_size=3, padding="same"
)(conv_y)
conv_z = keras.layers.BatchNormalization()(conv_z)

# no need to expand channels because they are equal
shortcut_y = keras.layers.BatchNormalization()(output_block_2)

output_block_3 = keras.layers.add([shortcut_y, conv_z])
output_block_3 = keras.layers.Activation("relu")(output_block_3)

# global average pooling

gap_layer = keras.layers.GlobalAveragePooling1D()(output_block_3)

return input_layer, gap_layer
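
A minimal sketch (assuming `tensorflow` is installed) of inspecting the network on its own, outside the classifier; the input shape `(24, 1)` is illustrative:

```python
from tensorflow import keras

from sktime.networks.resnet import ResNetNetwork

network = ResNetNetwork()
# input shape is (series_length, n_dimensions): 24 time points, 1 channel
input_layer, output_layer = network.build_network((24, 1))
model = keras.models.Model(inputs=input_layer, outputs=output_layer)
model.summary()  # three residual blocks followed by global average pooling
```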
3 changes: 3 additions & 0 deletions sktime/tests/_config.py
@@ -77,6 +77,9 @@
"test_save_estimators_to_file",
],
# `test_fit_idempotent` fails with `AssertionError`, see #3616
"ResNetClassifier": [
"test_fit_idempotent",
],
"CNNClassifier": [
"test_fit_idempotent",
],
