merge keras_mlp
leondgarse committed Aug 23, 2021
1 parent 0e8ab4e commit d2f2d8f
Showing 9 changed files with 899 additions and 59 deletions.
1 change: 1 addition & 0 deletions keras_cv_attention_models/__init__.py
@@ -8,3 +8,4 @@
 from keras_cv_attention_models import resnest
 from keras_cv_attention_models import resnext
 from keras_cv_attention_models import volo
+from keras_cv_attention_models import mlp
147 changes: 100 additions & 47 deletions keras_cv_attention_models/aotnet/aotnet.py
@@ -7,6 +7,7 @@

 BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5
+HALO_BLOCK_SIZE = 4
 CONV_KERNEL_INITIALIZER = tf.keras.initializers.VarianceScaling(scale=2.0, mode="fan_out", distribution="truncated_normal")

@@ -26,20 +27,34 @@ def batchnorm_with_activation(inputs, activation="relu", zero_gamma=False, name=
     return nn


-def conv2d_no_bias(inputs, filters, kernel_size, strides=1, padding="VALID", use_bias=False, name=None, **kwargs):
+def conv2d_no_bias(inputs, filters, kernel_size, strides=1, padding="VALID", use_bias=False, groups=1, name=None, **kwargs):
     pad = max(kernel_size) // 2 if isinstance(kernel_size, (list, tuple)) else kernel_size // 2
     if padding.upper() == "SAME" and pad != 0:
         inputs = keras.layers.ZeroPadding2D(padding=pad, name=name and name + "pad")(inputs)
-    return keras.layers.Conv2D(
-        filters,
-        kernel_size,
-        strides=strides,
-        padding="VALID",
-        use_bias=use_bias,
-        kernel_initializer=CONV_KERNEL_INITIALIZER,
-        name=name and name + "conv",
-        **kwargs,
-    )(inputs)
+
+    groups = groups if groups != 0 else 1
+    if groups == filters:
+        return keras.layers.DepthwiseConv2D(
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs
+        )(inputs)
+    else:
+        return keras.layers.Conv2D(
+            filters,
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            groups=groups,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs,
+        )(inputs)


 def se_module(inputs, se_ratio=0.25, activation="relu", use_bias=True, name=""):
@@ -57,6 +72,14 @@ def se_module(inputs, se_ratio=0.25, activation="relu", use_bias=True, name=""):
     return keras.layers.Multiply()([inputs, se])


+def drop_block(inputs, drop_rate=0, name=None):
+    if drop_rate > 0:
+        noise_shape = [None] + [1] * (len(inputs.shape) - 1)  # [None, 1, 1, 1]
+        return keras.layers.Dropout(drop_rate, noise_shape=noise_shape, name=name and name + "drop")(inputs)
+    else:
+        return inputs
+
+
 def anti_alias_downsample(inputs, kernel_size=3, strides=2, padding="SAME", trainable=False, name=None):
     def anti_alias_downsample_initializer(weight_shape, dtype="float32"):
         import numpy as np
@@ -109,48 +132,65 @@ def attn_block(inputs, filters, strides=1, attn_type=None, se_ratio=0, halo_bloc
     return nn


-def block(inputs, filters, preact=False, strides=1, conv_shortcut=False, expansion=4, attn_type=None, se_ratio=0, drop_rate=0, activation="relu", name=""):
-    expanded_filter = filters * expansion
-    halo_block_size = 4
-    if attn_type == "halo" and inputs.shape[1] % halo_block_size != 0:  # HaloAttention
-        gap = halo_block_size - inputs.shape[1] % halo_block_size
-        pad_head, pad_tail = gap // 2, gap - gap // 2
-        inputs = keras.layers.ZeroPadding2D(padding=((pad_head, pad_tail), (pad_head, pad_tail)), name=name + "gap_pad")(inputs)
-
-    shortcut = keras.layers.MaxPooling2D(strides, strides=strides, padding="SAME")(inputs) if strides > 1 else inputs
-
-    if preact:  # ResNetV2
-        inputs = batchnorm_with_activation(inputs, activation=activation, zero_gamma=False, name=name + "preact_")
+def conv_shortcut_branch(inputs, expanded_filter, preact=False, strides=1, avg_pool_down=True, anti_alias_down=False, name=""):
+    if strides > 1 and avg_pool_down:
+        shortcut = keras.layers.AvgPool2D(strides, strides=strides, padding="SAME", name=name + "shorcut_down")(inputs)
+        strides = 1
+    elif strides > 1 and anti_alias_down:
+        shortcut = anti_alias_downsample(inputs, kernel_size=3, strides=2, name=name + "shorcut_down")
+        strides = 1
+    else:
+        shortcut = inputs
+    shortcut = conv2d_no_bias(shortcut, expanded_filter, 1, strides=strides, name=name + "shortcut_")
+    if not preact:  # ResNet
+        shortcut = batchnorm_with_activation(shortcut, activation=None, zero_gamma=False, name=name + "shortcut_")
+    return shortcut

-    if conv_shortcut:  # Set a new shortcut using conv
-        shortcut = keras.layers.AvgPool2D(strides, strides=strides, padding="SAME", name=name + "shorcut_pool")(inputs) if strides > 1 else inputs
-        # shortcut = anti_alias_downsample(inputs, kernel_size=3, strides=2, name=name + "shorcut_") if strides > 1 else inputs
-        shortcut = conv2d_no_bias(shortcut, expanded_filter, 1, strides=1, name=name + "shorcut_")
-        # shortcut = conv2d_no_bias(inputs, expanded_filter, 1, strides=strides, name=name + "shorcut_")
-        if not preact:  # ResNet
-            shortcut = batchnorm_with_activation(shortcut, activation=None, zero_gamma=False, name=name + "shorcut_")
-
+
+def deep_branch(inputs, filters, strides=1, expansion=4, attn_type=None, se_ratio=0, activation="relu", name=""):
+    expanded_filter = filters * expansion
     if expansion > 1:
-        nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "1_")
+        nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "deep_1_")
     else:  # ResNet-RS like
-        nn = conv2d_no_bias(inputs, filters, 3, strides=1, padding="SAME", name=name + "1_")  # Using strides=1 for not changing input shape
+        nn = conv2d_no_bias(inputs, filters, 3, strides=1, padding="SAME", name=name + "deep_1_")  # Using strides=1 for not changing input shape
         # nn = conv2d_no_bias(inputs, filters, 3, strides=strides, padding="SAME", name=name + "1_")
         # strides = 1
-    nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "1_")
-    nn = attn_block(nn, filters, strides, attn_type, se_ratio / expansion, halo_block_size, True, activation, name=name + "2_")
+    nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "deep_1_")
+    nn = attn_block(nn, filters, strides, attn_type, se_ratio / expansion, HALO_BLOCK_SIZE, True, activation, name=name + "deep_2_")

     if expansion > 1:  # not ResNet-RS like
-        nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "3_")
+        nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "deep_3_")
+    return nn
+
+
+def block(inputs, filters, preact=False, strides=1, conv_shortcut=False, expansion=4, attn_type=None, se_ratio=0, drop_rate=0, activation="relu", name=""):
+    expanded_filter = filters * expansion
+    if attn_type == "halo" and inputs.shape[1] % HALO_BLOCK_SIZE != 0:  # HaloAttention
+        gap = HALO_BLOCK_SIZE - inputs.shape[1] % HALO_BLOCK_SIZE
+        pad_head, pad_tail = gap // 2, gap - gap // 2
+        inputs = keras.layers.ZeroPadding2D(padding=((pad_head, pad_tail), (pad_head, pad_tail)), name=name + "gap_pad")(inputs)

-    # print(">>>> shortcut:", shortcut.shape, "nn:", nn.shape)
-    if drop_rate > 0:
-        nn = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1, 1), name=name + "drop")(nn)
     if preact:  # ResNetV2
-        return keras.layers.Add(name=name + "add")([shortcut, nn])
+        pre_inputs = batchnorm_with_activation(inputs, activation=activation, zero_gamma=False, name=name + "preact_")
     else:
-        nn = batchnorm_with_activation(nn, activation=None, zero_gamma=True, name=name + "3_")
-        nn = keras.layers.Add(name=name + "add")([shortcut, nn])
-        return keras.layers.Activation(activation, name=name + "out")(nn)
+        pre_inputs = inputs
+
+    deep = deep_branch(pre_inputs, filters, strides, expansion, attn_type, se_ratio, activation=activation, name=name)
+
+    if conv_shortcut:  # Set a new shortcut using conv
+        shortcut = conv_shortcut_branch(pre_inputs, expanded_filter, preact, strides, avg_pool_down=True, anti_alias_down=False, name=name)
+    else:
+        shortcut = keras.layers.MaxPooling2D(strides, strides=strides, padding="SAME")(inputs) if strides > 1 else inputs
+
+    # print(">>>> shortcut:", shortcut.shape, "deep:", deep.shape)
+    if preact:  # ResNetV2
+        deep = drop_block(deep, drop_rate)
+        return keras.layers.Add(name=name + "add")([shortcut, deep])
+    else:
+        deep = batchnorm_with_activation(deep, activation=None, zero_gamma=True, name=name + "3_")
+        deep = drop_block(deep, drop_rate)
+        out = keras.layers.Add(name=name + "add")([shortcut, deep])
+        return keras.layers.Activation(activation, name=name + "out")(out)


 def stack1(inputs, blocks, filters, preact=False, strides=2, expansion=4, attn_types=None, se_ratio=0, stack_drop=0, activation="relu", name=""):
@@ -183,13 +223,19 @@ def stack2(inputs, blocks, filters, preact=True, strides=2, expansion=4, attn_ty
     return nn


-def stem(inputs, stem_width, activation="relu", deep_stem=False, name=""):
+def stem(inputs, stem_width, activation="relu", deep_stem=False, quad_stem=False, name=""):
     if deep_stem:
         nn = conv2d_no_bias(inputs, stem_width, 3, strides=2, padding="same", name=name + "1_")
         nn = batchnorm_with_activation(nn, activation=activation, name=name + "1_")
         nn = conv2d_no_bias(nn, stem_width, 3, strides=1, padding="same", name=name + "2_")
         nn = batchnorm_with_activation(nn, activation=activation, name=name + "2_")
         nn = conv2d_no_bias(nn, stem_width * 2, 3, strides=1, padding="same", name=name + "3_")
+    elif quad_stem:
+        nn = conv2d_no_bias(inputs, stem_width // 4, 3, strides=2, padding="same", name=name + "1_")
+        nn = conv2d_no_bias(nn, stem_width // 2, 3, strides=1, padding="same", name=name + "2_")
+        nn = conv2d_no_bias(nn, stem_width, 3, strides=1, padding="same", name=name + "3_")
+        nn = batchnorm_with_activation(nn, activation=activation, name=name + "1_")
+        nn = conv2d_no_bias(nn, stem_width * 2, 3, strides=2, padding="same", name=name + "4_")
     else:
         nn = conv2d_no_bias(inputs, stem_width, 7, strides=2, padding="same", name=name)
     return nn
Expand All @@ -203,20 +249,22 @@ def AotNet(
out_channels=[64, 128, 256, 512],
stem_width=64,
deep_stem=False,
quad_stem=False,
stem_downsample=True,
attn_types=None,
expansion=4,
se_ratio=0, # (0, 1)
num_features=0,
input_shape=(224, 224, 3),
num_classes=1000,
activation="relu",
drop_connect_rate=0,
classifier_activation="softmax",
model_name="aotnet",
**kwargs
kwargs=None
):
inputs = keras.layers.Input(shape=input_shape)
nn = stem(inputs, stem_width, activation=activation, deep_stem=deep_stem, name="stem_")
nn = stem(inputs, stem_width, activation=activation, deep_stem=deep_stem, quad_stem=quad_stem, name="stem_")

if not preact:
nn = batchnorm_with_activation(nn, activation=activation, name="stem_")
@@ -234,12 +282,17 @@
         stack_drop = (stack_drop_s, stack_drop_e)
         attn_type = attn_types[id] if isinstance(attn_types, (list, tuple)) else attn_types
         cur_se_ratio = se_ratio[id] if isinstance(se_ratio, (list, tuple)) else se_ratio
-        nn = stack(nn, num_block, out_channel, preact, stride, expansion, attn_type, cur_se_ratio, stack_drop, activation, name=name)
+        cur_expansion = expansion[id] if isinstance(expansion, (list, tuple)) else expansion
+        nn = stack(nn, num_block, out_channel, preact, stride, cur_expansion, attn_type, cur_se_ratio, stack_drop, activation, name=name)
         global_block_id += num_block

-    if preact:
+    if preact:  # resnetv2 like
         nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name="post_")

+    if num_features != 0:  # efficientnet like
+        nn = conv2d_no_bias(nn, num_features, 1, strides=1, name="features_")
+        nn = batchnorm_with_activation(nn, activation=activation, name="features_")
+
     if num_classes > 0:
         nn = keras.layers.GlobalAveragePooling2D(name="avg_pool")(nn)
         nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
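A quick sanity-check sketch for the two new helpers above. The import path and calls here are assumptions for illustration: `groups == filters` routes `conv2d_no_bias` to `DepthwiseConv2D`, any other `groups` value to a grouped `Conv2D`, and `drop_block` wraps `Dropout` with a per-sample `noise_shape` of `[None, 1, 1, 1]`.

```py
from tensorflow import keras
from keras_cv_attention_models.aotnet.aotnet import conv2d_no_bias, drop_block  # path assumed

inputs = keras.layers.Input([32, 32, 64])
dw = conv2d_no_bias(inputs, 64, 3, padding="SAME", groups=64, name="dw_")  # DepthwiseConv2D path
gc = conv2d_no_bias(inputs, 128, 3, padding="SAME", groups=4, name="gc_")  # grouped Conv2D path
out = drop_block(gc, drop_rate=0.2, name="gc_")  # Dropout with noise_shape [None, 1, 1, 1]
print(dw.shape, out.shape)  # (None, 32, 32, 64) (None, 32, 32, 128)
```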
98 changes: 98 additions & 0 deletions keras_cv_attention_models/mlp/README.md
@@ -0,0 +1,98 @@
# ___Keras MLP___
<!-- TOC depthFrom:1 depthTo:6 withLinks:1 updateOnSave:1 orderedList:0 -->

- [Keras MLP](#keras-mlp)
- [Usage](#usage)
- [MLP mixer](#mlp-mixer)
- [ResMLP](#resmlp)
- [GMLP](#gmlp)

<!-- /TOC -->
***

## Usage
- **Basic usage**
```py
from keras_cv_attention_models import mlp
# Will download and load `imagenet` pretrained weights.
# Model weight is loaded with `by_name=True, skip_mismatch=True`.
mm = mlp.MLPMixerB16(num_classes=1000, pretrained="imagenet")

# Run prediction
import tensorflow as tf
from tensorflow import keras
from skimage.data import chelsea # Chelsea the cat
imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # mode='tf' or 'torch'
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
# [('n02124075', 'Egyptian_cat', 0.9568315), ('n02123045', 'tabby', 0.017994137), ...]
```
For `"imagenet21k"` pre-trained models, actual `num_classes` is `21843`.
- **Exclude model top layers** by setting `num_classes=0`.
```py
from keras_cv_attention_models import mlp
mm = mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
print(mm.output_shape)
# (None, 784, 768)

mm.save('resmlp_b24_imagenet22k-notop.h5')
```
## MLP mixer
- [PDF 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/pdf/2105.01601.pdf).
- [Github google-research/vision_transformer](https://github.com/google-research/vision_transformer#available-mixer-models).
- **Models** `Top1 Acc` is the `ImageNet 1K` accuracy of models `pre-trained on JFT-300M`, as reported in the paper.
| Model | Params | Top1 Acc | ImageNet | Imagenet21k | ImageNet SAM |
| ----------- | ------ | -------- | --------------- | ------------------ | ------------------- |
| MLPMixerS32 | 19.1M | 68.70 | | | |
| MLPMixerS16 | 18.5M | 73.83 | | | |
| MLPMixerB32 | 60.3M | 75.53 | | | [b32_imagenet_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b32_imagenet_sam.h5) |
| MLPMixerB16 | 59.9M | 80.00 | [b16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet.h5) | [b16_imagenet21k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet21k.h5) | [b16_imagenet_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet_sam.h5) |
| MLPMixerL32 | 206.9M | 80.67 | | | |
| MLPMixerL16 | 208.2M | 84.82 | [l16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_l16_imagenet.h5) | [l16_imagenet21k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_l16_imagenet21k.h5) | |
| - input 448 | 208.2M | 86.78 | | | |
| MLPMixerH14 | 432.3M | 86.32 | | | |
| - input 448 | 432.3M | 87.94 | | | |

| Specification | S/32 | S/16 | B/32 | B/16 | L/32 | L/16 | H/14 |
| -------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| Number of layers | 8 | 8 | 12 | 12 | 24 | 24 | 32 |
| Patch resolution P×P | 32×32 | 16×16 | 32×32 | 16×16 | 32×32 | 16×16 | 14×14 |
| Hidden size C | 512 | 512 | 768 | 768 | 1024 | 1024 | 1280 |
| Sequence length S | 49 | 196 | 49 | 196 | 49 | 196 | 256 |
| MLP dimension DC | 2048 | 2048 | 3072 | 3072 | 4096 | 4096 | 5120 |
| MLP dimension DS | 256 | 256 | 384 | 384 | 512 | 512 | 640 |
- Parameter `pretrained` accepts one of `[None, "imagenet", "imagenet21k", "imagenet_sam"]`. Default is `"imagenet"`. A minimal sketch of a single Mixer block is shown after the pre-training notes below.
- **Pre-training details**
  - We pre-train all models using Adam with β1 = 0.9, β2 = 0.999, batch size 4096, weight decay, and gradient clipping at global norm 1.
- We use a linear learning rate warmup of 10k steps and linear decay.
- We pre-train all models at resolution 224.
- For JFT-300M, we pre-process images by applying the cropping technique from Szegedy et al. [44] in addition to random horizontal flipping.
- For ImageNet and ImageNet-21k, we employ additional data augmentation and regularization techniques.
- In particular, we use RandAugment [12], mixup [56], dropout [42], and stochastic depth [19].
- This set of techniques was inspired by the timm library [52] and Touvron et al. [46].
- More details on these hyperparameters are provided in Supplementary B.
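- **Block sketch** A minimal, illustrative Keras implementation of one Mixer block, matching the specification table above: a token-mixing MLP applied across the sequence axis, then a channel-mixing MLP across channels, each behind LayerNorm with a residual connection. Names and structure here are assumptions for illustration, not this library's internals.
  ```py
  from tensorflow import keras

  def mixer_block(inputs, tokens_mlp_dim=384, channels_mlp_dim=3072, activation="gelu"):
      # Token mixing: norm, transpose to [batch, channels, tokens], MLP over tokens, transpose back.
      nn = keras.layers.LayerNormalization(epsilon=1e-6)(inputs)
      nn = keras.layers.Permute((2, 1))(nn)
      nn = keras.layers.Dense(tokens_mlp_dim, activation=activation)(nn)
      nn = keras.layers.Dense(inputs.shape[1])(nn)
      nn = keras.layers.Permute((2, 1))(nn)
      token_out = keras.layers.Add()([inputs, nn])
      # Channel mixing: norm, MLP over the channel axis.
      nn = keras.layers.LayerNormalization(epsilon=1e-6)(token_out)
      nn = keras.layers.Dense(channels_mlp_dim, activation=activation)(nn)
      nn = keras.layers.Dense(token_out.shape[-1])(nn)
      return keras.layers.Add()([token_out, nn])

  inputs = keras.layers.Input([196, 768])  # B/16 at 224: S=196, C=768, DS=384, DC=3072
  print(mixer_block(inputs).shape)  # (None, 196, 768)
  ```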
## ResMLP
- [PDF 2105.03404 ResMLP: Feedforward networks for image classification with data-efficient training](https://arxiv.org/pdf/2105.03404.pdf)
- [Github facebookresearch/deit](https://github.com/facebookresearch/deit)
- **Models** The reloaded `imagenet` weights are the `distilled` versions from the official release.
| Model | Params | Image resolution | Top1 Acc | ImageNet |
| ---------- | ------ | ---------------- | -------- | -------- |
| ResMLP12 | 15M | 224 | 77.8 | [resmlp12_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp12_imagenet.h5) |
| ResMLP24 | 30M | 224 | 80.8 | [resmlp24_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp24_imagenet.h5) |
| ResMLP36 | 116M | 224 | 81.1 | [resmlp36_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp36_imagenet.h5) |
| ResMLP_B24 | 129M | 224 | 83.6 | [resmlp_b24_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp_b24_imagenet.h5) |
| - imagenet22k | 129M | 224 | 84.4 | [resmlp_b24_imagenet22k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp_b24_imagenet22k.h5) |

- Parameter `pretrained` accepts one of `[None, "imagenet", "imagenet22k"]`, where `"imagenet22k"` means pre-trained on `imagenet21k` and fine-tuned on `imagenet`. Default is `"imagenet"`.
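- **Block sketch** A minimal, illustrative Keras version of one ResMLP block, following the paper: an `Affine` transform (learned per-channel scale and shift, replacing LayerNorm), a linear layer mixing patches, then an `Affine` plus a two-layer channel MLP, each with a residual connection. The paper's LayerScale residual scaling is omitted for brevity; names and structure are assumptions, not this library's internals.
  ```py
  from tensorflow import keras

  class Affine(keras.layers.Layer):
      # Learned per-channel scale and shift, used in place of LayerNorm.
      def build(self, input_shape):
          self.alpha = self.add_weight(name="alpha", shape=(input_shape[-1],), initializer="ones")
          self.beta = self.add_weight(name="beta", shape=(input_shape[-1],), initializer="zeros")
      def call(self, inputs):
          return self.alpha * inputs + self.beta

  def resmlp_block(inputs, mlp_ratio=4, activation="gelu"):
      # Cross-patch sublayer: a single linear layer mixing the token axis.
      nn = Affine()(inputs)
      nn = keras.layers.Permute((2, 1))(nn)
      nn = keras.layers.Dense(inputs.shape[1])(nn)
      nn = keras.layers.Permute((2, 1))(nn)
      nn = keras.layers.Add()([inputs, nn])
      # Cross-channel sublayer: two-layer MLP over channels.
      mlp = Affine()(nn)
      mlp = keras.layers.Dense(inputs.shape[-1] * mlp_ratio, activation=activation)(mlp)
      mlp = keras.layers.Dense(inputs.shape[-1])(mlp)
      return keras.layers.Add()([nn, mlp])

  inputs = keras.layers.Input([196, 384])  # e.g. ResMLP12: 196 patches, width 384
  print(resmlp_block(inputs).shape)  # (None, 196, 384)
  ```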
## GMLP
- [PDF 2105.08050 Pay Attention to MLPs](https://arxiv.org/pdf/2105.08050.pdf).
- Model weights reloaded from [Github timm/models/mlp_mixer](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py).
- **Models**
| Model | Params | Image resolution | Top1 Acc | ImageNet |
| ---------- | ------ | ---------------- | -------- | -------- |
| GMLPTiny16 | 6M | 224 | 72.3 | |
| GMLPS16 | 20M | 224 | 79.6 | [gmlp_s16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/gmlp_s16_imagenet.h5) |
| GMLPB16 | 73M | 224 | 81.6 | |

- Parameter `pretrained` accepts one of `[None, "imagenet"]`. Default is `"imagenet"`.
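- **Block sketch** A minimal, illustrative Keras version of one gMLP block with its Spatial Gating Unit (SGU), following the paper: channels are projected up, split in half, and one half is mixed across the token axis to gate the other element-wise. Names and structure are assumptions, not this library's internals.
  ```py
  import tensorflow as tf
  from tensorflow import keras

  def gmlp_block(inputs, mlp_dim=1536, activation="gelu"):
      nn = keras.layers.LayerNormalization(epsilon=1e-6)(inputs)
      nn = keras.layers.Dense(mlp_dim, activation=activation)(nn)
      # Spatial Gating Unit: split channels, mix one half over tokens, gate the other.
      u, v = tf.split(nn, 2, axis=-1)
      v = keras.layers.LayerNormalization(epsilon=1e-6)(v)
      v = keras.layers.Permute((2, 1))(v)
      v = keras.layers.Dense(inputs.shape[1], bias_initializer="ones")(v)  # near-identity gate at init
      v = keras.layers.Permute((2, 1))(v)
      nn = keras.layers.Multiply()([u, v])
      nn = keras.layers.Dense(inputs.shape[-1])(nn)
      return keras.layers.Add()([inputs, nn])

  inputs = keras.layers.Input([196, 256])  # e.g. GMLPS16 at 224: 196 tokens, width 256
  print(gmlp_block(inputs).shape)  # (None, 196, 256)
  ```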
***