diff --git a/keras_cv_attention_models/__init__.py b/keras_cv_attention_models/__init__.py
index 5b7392b3..f70ea9c6 100644
--- a/keras_cv_attention_models/__init__.py
+++ b/keras_cv_attention_models/__init__.py
@@ -8,3 +8,4 @@
 from keras_cv_attention_models import resnest
 from keras_cv_attention_models import resnext
 from keras_cv_attention_models import volo
+from keras_cv_attention_models import mlp
diff --git a/keras_cv_attention_models/aotnet/aotnet.py b/keras_cv_attention_models/aotnet/aotnet.py
index 3581de28..e7bbb841 100644
--- a/keras_cv_attention_models/aotnet/aotnet.py
+++ b/keras_cv_attention_models/aotnet/aotnet.py
@@ -7,6 +7,7 @@
 BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5
+HALO_BLOCK_SIZE = 4

 CONV_KERNEL_INITIALIZER = tf.keras.initializers.VarianceScaling(scale=2.0, mode="fan_out", distribution="truncated_normal")
@@ -26,20 +27,34 @@ def batchnorm_with_activation(inputs, activation="relu", zero_gamma=False, name=
     return nn


-def conv2d_no_bias(inputs, filters, kernel_size, strides=1, padding="VALID", use_bias=False, name=None, **kwargs):
+def conv2d_no_bias(inputs, filters, kernel_size, strides=1, padding="VALID", use_bias=False, groups=1, name=None, **kwargs):
     pad = max(kernel_size) // 2 if isinstance(kernel_size, (list, tuple)) else kernel_size // 2
     if padding.upper() == "SAME" and pad != 0:
         inputs = keras.layers.ZeroPadding2D(padding=pad, name=name and name + "pad")(inputs)
-    return keras.layers.Conv2D(
-        filters,
-        kernel_size,
-        strides=strides,
-        padding="VALID",
-        use_bias=use_bias,
-        kernel_initializer=CONV_KERNEL_INITIALIZER,
-        name=name and name + "conv",
-        **kwargs,
-    )(inputs)
+
+    groups = groups if groups != 0 else 1
+    if groups == filters:
+        return keras.layers.DepthwiseConv2D(
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs,
+        )(inputs)
+    else:
+        return keras.layers.Conv2D(
+            filters,
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            groups=groups,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs,
+        )(inputs)


 def se_module(inputs, se_ratio=0.25, activation="relu", use_bias=True, name=""):
@@ -57,6 +72,14 @@ def se_module(inputs, se_ratio=0.25, activation="relu", use_bias=True, name=""):
     return keras.layers.Multiply()([inputs, se])


+def drop_block(inputs, drop_rate=0, name=None):
+    if drop_rate > 0:
+        noise_shape = [None] + [1] * (len(inputs.shape) - 1)  # [None, 1, 1, 1]
+        return keras.layers.Dropout(drop_rate, noise_shape=noise_shape, name=name and name + "drop")(inputs)
+    else:
+        return inputs
+
+
 def anti_alias_downsample(inputs, kernel_size=3, strides=2, padding="SAME", trainable=False, name=None):
     def anti_alias_downsample_initializer(weight_shape, dtype="float32"):
         import numpy as np
@@ -109,48 +132,65 @@ def attn_block(inputs, filters, strides=1, attn_type=None, se_ratio=0, halo_bloc
     return nn


-def block(inputs, filters, preact=False, strides=1, conv_shortcut=False, expansion=4, attn_type=None, se_ratio=0, drop_rate=0, activation="relu", name=""):
-    expanded_filter = filters * expansion
-    halo_block_size = 4
-    if attn_type == "halo" and inputs.shape[1] % halo_block_size != 0:  # HaloAttention
-        gap = halo_block_size - inputs.shape[1] % halo_block_size
-        pad_head, pad_tail = gap // 2, gap - gap // 2
-        inputs = keras.layers.ZeroPadding2D(padding=((pad_head, pad_tail), (pad_head, pad_tail)), name=name + "gap_pad")(inputs)
-
-    shortcut = keras.layers.MaxPooling2D(strides, strides=strides, padding="SAME")(inputs) if strides > 1 else inputs
-
-    if preact:  # ResNetV2
-        inputs = batchnorm_with_activation(inputs, activation=activation, zero_gamma=False, name=name + "preact_")
+def conv_shortcut_branch(inputs, expanded_filter, preact=False, strides=1, avg_pool_down=True, anti_alias_down=False, name=""):
+    if strides > 1 and avg_pool_down:
+        shortcut = keras.layers.AvgPool2D(strides, strides=strides, padding="SAME", name=name + "shortcut_down")(inputs)
+        strides = 1
+    elif strides > 1 and anti_alias_down:
+        shortcut = anti_alias_downsample(inputs, kernel_size=3, strides=2, name=name + "shortcut_down")
+        strides = 1
+    else:
+        shortcut = inputs
+    shortcut = conv2d_no_bias(shortcut, expanded_filter, 1, strides=strides, name=name + "shortcut_")
+    if not preact:  # ResNet
+        shortcut = batchnorm_with_activation(shortcut, activation=None, zero_gamma=False, name=name + "shortcut_")
+    return shortcut

-    if conv_shortcut:  # Set a new shortcut using conv
-        shortcut = keras.layers.AvgPool2D(strides, strides=strides, padding="SAME", name=name + "shorcut_pool")(inputs) if strides > 1 else inputs
-        # shortcut = anti_alias_downsample(inputs, kernel_size=3, strides=2, name=name + "shorcut_") if strides > 1 else inputs
-        shortcut = conv2d_no_bias(shortcut, expanded_filter, 1, strides=1, name=name + "shorcut_")
-        # shortcut = conv2d_no_bias(inputs, expanded_filter, 1, strides=strides, name=name + "shorcut_")
-        if not preact:  # ResNet
-            shortcut = batchnorm_with_activation(shortcut, activation=None, zero_gamma=False, name=name + "shorcut_")

+def deep_branch(inputs, filters, strides=1, expansion=4, attn_type=None, se_ratio=0, activation="relu", name=""):
+    expanded_filter = filters * expansion
     if expansion > 1:
-        nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "1_")
+        nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "deep_1_")
     else:  # ResNet-RS like
-        nn = conv2d_no_bias(inputs, filters, 3, strides=1, padding="SAME", name=name + "1_")  # Using strides=1 for not changing input shape
+        nn = conv2d_no_bias(inputs, filters, 3, strides=1, padding="SAME", name=name + "deep_1_")  # Using strides=1 for not changing input shape
         # nn = conv2d_no_bias(inputs, filters, 3, strides=strides, padding="SAME", name=name + "1_")
         # strides = 1
-    nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "1_")
-    nn = attn_block(nn, filters, strides, attn_type, se_ratio / expansion, halo_block_size, True, activation, name=name + "2_")
+    nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "deep_1_")
+    nn = attn_block(nn, filters, strides, attn_type, se_ratio / expansion, HALO_BLOCK_SIZE, True, activation, name=name + "deep_2_")

     if expansion > 1:  # not ResNet-RS like
-        nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "3_")
+        nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "deep_3_")
+    return nn
+
+
+def block(inputs, filters, preact=False, strides=1, conv_shortcut=False, expansion=4, attn_type=None, se_ratio=0, drop_rate=0, activation="relu", name=""):
+    expanded_filter = filters * expansion
+    if attn_type == "halo" and inputs.shape[1] % HALO_BLOCK_SIZE != 0:  # HaloAttention
+        gap = HALO_BLOCK_SIZE - inputs.shape[1] % HALO_BLOCK_SIZE
+        pad_head, pad_tail = gap // 2, gap - gap // 2
+        inputs = keras.layers.ZeroPadding2D(padding=((pad_head, pad_tail), (pad_head, pad_tail)), name=name + "gap_pad")(inputs)
shortcut:", shortcut.shape, "nn:", nn.shape) - if drop_rate > 0: - nn = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1, 1), name=name + "drop")(nn) if preact: # ResNetV2 - return keras.layers.Add(name=name + "add")([shortcut, nn]) + pre_inputs = batchnorm_with_activation(inputs, activation=activation, zero_gamma=False, name=name + "preact_") else: - nn = batchnorm_with_activation(nn, activation=None, zero_gamma=True, name=name + "3_") - nn = keras.layers.Add(name=name + "add")([shortcut, nn]) - return keras.layers.Activation(activation, name=name + "out")(nn) + pre_inputs = inputs + + deep = deep_branch(pre_inputs, filters, strides, expansion, attn_type, se_ratio, activation=activation, name=name) + + if conv_shortcut: # Set a new shortcut using conv + shortcut = conv_shortcut_branch(pre_inputs, expanded_filter, preact, strides, avg_pool_down=True, anti_alias_down=False, name=name) + else: + shortcut = keras.layers.MaxPooling2D(strides, strides=strides, padding="SAME")(inputs) if strides > 1 else inputs + + # print(">>>> shortcut:", shortcut.shape, "deep:", deep.shape) + if preact: # ResNetV2 + deep = drop_block(deep, drop_rate) + return keras.layers.Add(name=name + "add")([shortcut, deep]) + else: + deep = batchnorm_with_activation(deep, activation=None, zero_gamma=True, name=name + "3_") + deep = drop_block(deep, drop_rate) + out = keras.layers.Add(name=name + "add")([shortcut, deep]) + return keras.layers.Activation(activation, name=name + "out")(out) def stack1(inputs, blocks, filters, preact=False, strides=2, expansion=4, attn_types=None, se_ratio=0, stack_drop=0, activation="relu", name=""): @@ -183,13 +223,19 @@ def stack2(inputs, blocks, filters, preact=True, strides=2, expansion=4, attn_ty return nn -def stem(inputs, stem_width, activation="relu", deep_stem=False, name=""): +def stem(inputs, stem_width, activation="relu", deep_stem=False, quad_stem=False, name=""): if deep_stem: nn = conv2d_no_bias(inputs, stem_width, 3, strides=2, padding="same", name=name + "1_") nn = batchnorm_with_activation(nn, activation=activation, name=name + "1_") nn = conv2d_no_bias(nn, stem_width, 3, strides=1, padding="same", name=name + "2_") nn = batchnorm_with_activation(nn, activation=activation, name=name + "2_") nn = conv2d_no_bias(nn, stem_width * 2, 3, strides=1, padding="same", name=name + "3_") + elif quad_stem: + nn = conv2d_no_bias(inputs, stem_width // 4, 3, strides=2, padding="same", name=name + "1_") + nn = conv2d_no_bias(nn, stem_width // 2, 3, strides=1, padding="same", name=name + "2_") + nn = conv2d_no_bias(nn, stem_width, 3, strides=1, padding="same", name=name + "3_") + nn = batchnorm_with_activation(nn, activation=activation, name=name + "1_") + nn = conv2d_no_bias(nn, stem_width * 2, 3, strides=2, padding="same", name=name + "4_") else: nn = conv2d_no_bias(inputs, stem_width, 7, strides=2, padding="same", name=name) return nn @@ -203,20 +249,22 @@ def AotNet( out_channels=[64, 128, 256, 512], stem_width=64, deep_stem=False, + quad_stem=False, stem_downsample=True, attn_types=None, expansion=4, se_ratio=0, # (0, 1) + num_features=0, input_shape=(224, 224, 3), num_classes=1000, activation="relu", drop_connect_rate=0, classifier_activation="softmax", model_name="aotnet", - **kwargs + kwargs=None ): inputs = keras.layers.Input(shape=input_shape) - nn = stem(inputs, stem_width, activation=activation, deep_stem=deep_stem, name="stem_") + nn = stem(inputs, stem_width, activation=activation, deep_stem=deep_stem, quad_stem=quad_stem, name="stem_") if not preact: nn = 
         nn = batchnorm_with_activation(nn, activation=activation, name="stem_")
@@ -234,12 +282,17 @@ def AotNet(
         stack_drop = (stack_drop_s, stack_drop_e)
         attn_type = attn_types[id] if isinstance(attn_types, (list, tuple)) else attn_types
         cur_se_ratio = se_ratio[id] if isinstance(se_ratio, (list, tuple)) else se_ratio
-        nn = stack(nn, num_block, out_channel, preact, stride, expansion, attn_type, cur_se_ratio, stack_drop, activation, name=name)
+        cur_expansion = expansion[id] if isinstance(expansion, (list, tuple)) else expansion
+        nn = stack(nn, num_block, out_channel, preact, stride, cur_expansion, attn_type, cur_se_ratio, stack_drop, activation, name=name)
         global_block_id += num_block

-    if preact:
+    if preact:  # resnetv2 like
         nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name="post_")

+    if num_features != 0:  # efficientnet like
+        nn = conv2d_no_bias(nn, num_features, 1, strides=1, name="features_")
+        nn = batchnorm_with_activation(nn, activation=activation, name="features_")
+
     if num_classes > 0:
         nn = keras.layers.GlobalAveragePooling2D(name="avg_pool")(nn)
         nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
diff --git a/keras_cv_attention_models/mlp/README.md b/keras_cv_attention_models/mlp/README.md
new file mode 100644
index 00000000..f718dcdf
--- /dev/null
+++ b/keras_cv_attention_models/mlp/README.md
@@ -0,0 +1,98 @@
+# ___Keras MLP___
+
+- [Keras MLP](#keras-mlp)
+  - [Usage](#usage)
+  - [MLP mixer](#mlp-mixer)
+  - [ResMLP](#resmlp)
+  - [GMLP](#gmlp)
+
+***
+
+## Usage
+  - **Basic usage**
+    ```py
+    from keras_cv_attention_models import mlp
+    # Will download and load `imagenet` pretrained weights.
+    # Model weights are loaded with `by_name=True, skip_mismatch=True`.
+    mm = mlp.MLPMixerB16(num_classes=1000, pretrained="imagenet")
+
+    # Run prediction
+    import tensorflow as tf
+    from tensorflow import keras
+    from skimage.data import chelsea  # Chelsea the cat
+    imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf')  # mode="tf" or "torch"
+    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
+    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
+    # [('n02124075', 'Egyptian_cat', 0.9568315), ('n02123045', 'tabby', 0.017994137), ...]
+    ```
+    For `"imagenet21k"` pre-trained models, actual `num_classes` is `21843`.
+  - **Exclude model top layers** by setting `num_classes=0`.
+    ```py
+    from keras_cv_attention_models import mlp
+    mm = mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
+    print(mm.output_shape)
+    # (None, 784, 768)
+
+    mm.save('resmlp_b24_imagenet22k-notop.h5')
+    ```
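+  - **SAM training** (a minimal sketch, parameter values illustrative only): a non-zero `sam_rho` wraps the functional graph in `SAMModel` from `model_surgery`, which applies the SAM training step; the result compiles and fits like any `keras.Model`.
+    ```py
+    from keras_cv_attention_models import mlp
+    mm = mlp.MLPMixerB16(num_classes=1000, sam_rho=0.05, pretrained=None)
+    mm.compile(optimizer="adam", loss="categorical_crossentropy")
+    ```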
+## MLP mixer
+  - [PDF 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/pdf/2105.01601.pdf).
+  - [Github google-research/vision_transformer](https://github.com/google-research/vision_transformer#available-mixer-models).
+  - **Models** `Top1 Acc` is the `Pre-trained on JFT-300M` model accuracy on `ImageNet 1K` from the paper.
+
+    | Model       | Params | Top1 Acc | ImageNet | Imagenet21k | ImageNet SAM |
+    | ----------- | ------ | -------- | -------- | ----------- | ------------ |
+    | MLPMixerS32 | 19.1M  | 68.70    |          |             |              |
+    | MLPMixerS16 | 18.5M  | 73.83    |          |             |              |
+    | MLPMixerB32 | 60.3M  | 75.53    |          |             | [b32_imagenet_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b32_imagenet_sam.h5) |
+    | MLPMixerB16 | 59.9M  | 80.00    | [b16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet.h5) | [b16_imagenet21k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet21k.h5) | [b16_imagenet_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_b16_imagenet_sam.h5) |
+    | MLPMixerL32 | 206.9M | 80.67    |          |             |              |
+    | MLPMixerL16 | 208.2M | 84.82    | [l16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_l16_imagenet.h5) | [l16_imagenet21k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/mlp_mixer_l16_imagenet21k.h5) | |
+    | - input 448 | 208.2M | 86.78    |          |             |              |
+    | MLPMixerH14 | 432.3M | 86.32    |          |             |              |
+    | - input 448 | 432.3M | 87.94    |          |             |              |
+
+    | Specification        | S/32  | S/16  | B/32  | B/16  | L/32  | L/16  | H/14  |
+    | -------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
+    | Number of layers     | 8     | 8     | 12    | 12    | 24    | 24    | 32    |
+    | Patch resolution P×P | 32×32 | 16×16 | 32×32 | 16×16 | 32×32 | 16×16 | 14×14 |
+    | Hidden size C        | 512   | 512   | 768   | 768   | 1024  | 1024  | 1280  |
+    | Sequence length S    | 49    | 196   | 49    | 196   | 49    | 196   | 256   |
+    | MLP dimension DC     | 2048  | 2048  | 3072  | 3072  | 4096  | 4096  | 5120  |
+    | MLP dimension DS     | 256   | 256   | 384   | 384   | 512   | 512   | 640   |
+  - Parameter `pretrained` accepts one of `[None, "imagenet", "imagenet21k", "imagenet_sam"]`. Default is `"imagenet"`.
+  - **Pre-training details**
+    - We pre-train all models using Adam with β1 = 0.9, β2 = 0.999, and batch size 4096, using weight decay, and gradient clipping at global norm 1.
+    - We use a linear learning rate warmup of 10k steps and linear decay.
+    - We pre-train all models at resolution 224.
+    - For JFT-300M, we pre-process images by applying the cropping technique from Szegedy et al. [44] in addition to random horizontal flipping.
+    - For ImageNet and ImageNet-21k, we employ additional data augmentation and regularization techniques.
+    - In particular, we use RandAugment [12], mixup [56], dropout [42], and stochastic depth [19].
+    - This set of techniques was inspired by the timm library [52] and Touvron et al. [46].
+    - More details on these hyperparameters are provided in Supplementary B.
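+  - The `Sequence length S` row in the specification table follows directly from the input and patch sizes, e.g. for B/16 at the default 224 resolution:
+    ```py
+    (224 // 16) ** 2  # 196, the sequence length S for B/16
+    ```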
+## ResMLP
+  - [PDF 2105.03404 ResMLP: Feedforward networks for image classification with data-efficient training](https://arxiv.org/pdf/2105.03404.pdf)
+  - [Github facebookresearch/deit](https://github.com/facebookresearch/deit)
+  - **Models** Reloaded `imagenet` weights are the `distilled` version from the official release.
+
+    | Model         | Params | Image resolution | Top1 Acc | ImageNet |
+    | ------------- | ------ | ---------------- | -------- | -------- |
+    | ResMLP12      | 15M    | 224              | 77.8     | [resmlp12_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp12_imagenet.h5) |
+    | ResMLP24      | 30M    | 224              | 80.8     | [resmlp24_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp24_imagenet.h5) |
+    | ResMLP36      | 116M   | 224              | 81.1     | [resmlp36_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp36_imagenet.h5) |
+    | ResMLP_B24    | 129M   | 224              | 83.6     | [resmlp_b24_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp_b24_imagenet.h5) |
+    | - imagenet22k | 129M   | 224              | 84.4     | [resmlp_b24_imagenet22k.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/resmlp_b24_imagenet22k.h5) |
+
+  - Parameter `pretrained` accepts one of `[None, "imagenet", "imagenet22k"]`, where `imagenet22k` means pre-trained on `imagenet21k` and fine-tuned on `imagenet`. Default is `"imagenet"`.
+## GMLP
+  - [PDF 2105.08050 Pay Attention to MLPs](https://arxiv.org/pdf/2105.08050.pdf).
+  - Model weights reloaded from [Github timm/models/mlp_mixer](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py).
+  - **Models**
+
+    | Model      | Params | Image resolution | Top1 Acc | ImageNet |
+    | ---------- | ------ | ---------------- | -------- | -------- |
+    | GMLPTiny16 | 6M     | 224              | 72.3     |          |
+    | GMLPS16    | 20M    | 224              | 79.6     | [gmlp_s16_imagenet.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/gmlp_s16_imagenet.h5) |
+    | GMLPB16    | 73M    | 224              | 81.6     |          |
+
+  - Parameter `pretrained` accepts one of `[None, "imagenet"]`. Default is `"imagenet"`.
+***
diff --git a/keras_cv_attention_models/mlp/__init__.py b/keras_cv_attention_models/mlp/__init__.py
new file mode 100644
index 00000000..7794223d
--- /dev/null
+++ b/keras_cv_attention_models/mlp/__init__.py
@@ -0,0 +1,146 @@
+from keras_cv_attention_models.mlp import mlp_mixer, res_mlp, gated_mlp
+from keras_cv_attention_models.mlp.mlp_mixer import MLPMixer, MLPMixerS32, MLPMixerS16, MLPMixerB32, MLPMixerB16, MLPMixerL32, MLPMixerL16, MLPMixerH14
+from keras_cv_attention_models.mlp.res_mlp import ResMLP, ResMLP12, ResMLP24, ResMLP36, ResMLP_B24
+from keras_cv_attention_models.mlp.gated_mlp import GMLP, GMLPTiny16, GMLPS16, GMLPB16
+
+__mlp_mixer_head_doc__ = """
+Github source [leondgarse/keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_models).
+Keras implementation of [Github google-research/vision_transformer](https://github.com/google-research/vision_transformer#available-mixer-models).
+Paper [PDF 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/pdf/2105.01601.pdf).
+"""
+
+__tail_doc__ = """  input_shape: it should have exactly 3 input channels like `(224, 224, 3)`.
+  num_classes: number of classes to classify images into. Set `0` to exclude top layers.
+      For `"imagenet21k"` pre-trained model, actual `num_classes` is `21843`.
+  activation: activation used in whole model, default `gelu`.
+  sam_rho: non-zero value to init model using the `SAM` training step.
+      SAM Arxiv article: [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/pdf/2010.01412.pdf).
+  dropout: dropout rate if top layers are included.
+  drop_connect_rate: dropout rate for [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382).
+      Can be a constant value like `0.2`, a tuple value like `(0, 0.2)` indicating the drop probability
+      linearly changing from `0 --> 0.2` for `top --> bottom` layers, or `0` to disable (default).
+      A higher value means a higher probability of dropping the deep branch.
+  classifier_activation: A `str` or callable. The activation function to use on the `top` layer if `num_classes > 0`.
+      Set `classifier_activation=None` to return the logits of the `top` layer.
+      Default is `softmax`.
+  pretrained: value in {pretrained_list}.
+      Will try to download and load pre-trained model weights if not None.
+      Save path is `~/.keras/models/`.
+
+Returns:
+    A `keras.Model` instance.
+"""
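+
+# A worked example of the `drop_connect_rate` schedule above (illustrative values):
+# `drop_connect_rate=(0, 0.2)` with `num_blocks=12` gives block `ii` a dropout rate
+# of `0 + (0.2 - 0) * ii / 12`, increasing linearly with depth.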
+
+MLPMixer.__doc__ = __mlp_mixer_head_doc__ + """
+Args:
+  num_blocks: number of layers.
+  patch_size: stem patch resolution P×P, means `kernel_size=patch_size, strides=patch_size` for stem `Conv2D` block.
+  stem_width: stem output channel dimension.
+  tokens_mlp_dim: MLP block token level hidden dimension, where token level means the `height * width` dimension.
+  channels_mlp_dim: MLP block channel level hidden dimension.
+  model_name: string, model name.
+""" + __tail_doc__.format(pretrained_list=[None, "imagenet", "imagenet21k", "imagenet_sam"]) + """
+Model architectures:
+  | Model       | Params | Top1 Acc | Pre-trained                         |
+  | ----------- | ------ | -------- | ----------------------------------- |
+  | MLPMixerS32 | 19.1M  | 68.70    | None                                |
+  | MLPMixerS16 | 18.5M  | 73.83    | None                                |
+  | MLPMixerB32 | 60.3M  | 75.53    | imagenet_sam                        |
+  | MLPMixerB16 | 59.9M  | 80.00    | imagenet, imagenet_sam, imagenet21k |
+  | MLPMixerL32 | 206.9M | 80.67    | None                                |
+  | MLPMixerL16 | 208.2M | 84.82    | imagenet, imagenet21k               |
+  | - input 448 | 208.2M | 86.78    | None                                |
+  | MLPMixerH14 | 432.3M | 86.32    | None                                |
+  | - input 448 | 432.3M | 87.94    | None                                |
+
+  | Specification        | S/32  | S/16  | B/32  | B/16  | L/32  | L/16  | H/14  |
+  | -------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
+  | Number of layers     | 8     | 8     | 12    | 12    | 24    | 24    | 32    |
+  | Patch resolution P×P | 32×32 | 16×16 | 32×32 | 16×16 | 32×32 | 16×16 | 14×14 |
+  | Hidden size C        | 512   | 512   | 768   | 768   | 1024  | 1024  | 1280  |
+  | Sequence length S    | 49    | 196   | 49    | 196   | 49    | 196   | 256   |
+  | MLP dimension DC     | 2048  | 2048  | 3072  | 3072  | 4096  | 4096  | 5120  |
+  | MLP dimension DS     | 256   | 256   | 384   | 384   | 512   | 512   | 640   |
+"""
+
+
+__mixer_default_doc__ = __mlp_mixer_head_doc__ + """
+[{model_name} architecture] num_blocks: {num_blocks}, patch_size: {patch_size}, stem_width: {stem_width}, tokens_mlp_dim: {tokens_mlp_dim}, channels_mlp_dim: {channels_mlp_dim}.
+
+Args:
+""" + __tail_doc__.format(pretrained_list=[None, "imagenet", "imagenet21k", "imagenet_sam"])
+
+MLPMixerS32.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerS32", **mlp_mixer.BLOCK_CONFIGS["s32"])
+MLPMixerS16.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerS16", **mlp_mixer.BLOCK_CONFIGS["s16"])
+MLPMixerB32.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerB32", **mlp_mixer.BLOCK_CONFIGS["b32"])
+MLPMixerB16.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerB16", **mlp_mixer.BLOCK_CONFIGS["b16"])
+MLPMixerL32.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerL32", **mlp_mixer.BLOCK_CONFIGS["l32"])
+MLPMixerL16.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerL16", **mlp_mixer.BLOCK_CONFIGS["l16"])
+MLPMixerH14.__doc__ = __mixer_default_doc__.format(model_name="MLPMixerH14", **mlp_mixer.BLOCK_CONFIGS["h14"])
+
+__resmlp_head_doc__ = """
+Github source [leondgarse/keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_models).
+Keras implementation of [Github facebookresearch/deit](https://github.com/facebookresearch/deit).
+Paper [PDF 2105.03404 ResMLP: Feedforward networks for image classification with data-efficient training](https://arxiv.org/pdf/2105.03404.pdf).
+"""
+
+ResMLP.__doc__ = __resmlp_head_doc__ + """
+Args:
+  num_blocks: number of layers.
+  patch_size: stem patch resolution P×P, means `kernel_size=patch_size, strides=patch_size` for stem `Conv2D` block.
+  stem_width: stem output channel dimension.
+  channels_mlp_dim: MLP block channel level hidden dimension.
+  model_name: string, model name.
+""" + __tail_doc__.format(pretrained_list=[None, "imagenet", "imagenet22k"]) + """
+Model architectures:
+  | Model         | Params | Image resolution | Top1 Acc | Pre-trained |
+  | ------------- | ------ | ---------------- | -------- | ----------- |
+  | ResMLP12      | 15M    | 224              | 77.8     | imagenet    |
+  | ResMLP24      | 30M    | 224              | 80.8     | imagenet    |
+  | ResMLP36      | 116M   | 224              | 81.1     | imagenet    |
+  | ResMLP_B24    | 129M   | 224              | 83.6     | imagenet    |
+  | - imagenet22k | 129M   | 224              | 84.4     | imagenet22k |
+"""
+
+__resmlp_default_doc__ = __resmlp_head_doc__ + """
+[{model_name} architecture] num_blocks: {num_blocks}, patch_size: {patch_size}, stem_width: {stem_width}, channels_mlp_dim: {channels_mlp_dim}.
+
+Args:
+""" + __tail_doc__.format(pretrained_list=[None, "imagenet", "imagenet22k"])
+
+ResMLP12.__doc__ = __resmlp_default_doc__.format(model_name="ResMLP12", **res_mlp.BLOCK_CONFIGS["12"])
+ResMLP24.__doc__ = __resmlp_default_doc__.format(model_name="ResMLP24", **res_mlp.BLOCK_CONFIGS["24"])
+ResMLP36.__doc__ = __resmlp_default_doc__.format(model_name="ResMLP36", **res_mlp.BLOCK_CONFIGS["36"])
+ResMLP_B24.__doc__ = __resmlp_default_doc__.format(model_name="ResMLP_B24", **res_mlp.BLOCK_CONFIGS["b24"])
+""" + __tail_doc__.format(pretrained_list=[None, "imagenet"]) + """ +Model architectures: + | Model | Params | Image resolution | Top1 Acc | Pre-trained | + | ---------- | ------ | ---------------- | -------- | ----------- | + | GMLPTiny16 | 6M | 224 | 72.3 | None | + | GMLPS16 | 20M | 224 | 79.6 | imagenet | + | GMLPB16 | 73M | 224 | 81.6 | None | +""" + +__gmlp_default_doc__ = __gmlp_head_doc__ + """ +[{model_name} architecture] num_blocks: {num_blocks}, patch_size: {patch_size}, stem_width: {stem_width}, channels_mlp_dim: {channels_mlp_dim}. + +Args: +""" + __tail_doc__.format(pretrained_list=[None, "imagenet"]) + +GMLPTiny16.__doc__ = __gmlp_default_doc__.format(model_name="GMLPTiny16", **gated_mlp.BLOCK_CONFIGS["tiny16"]) +GMLPS16.__doc__ = __gmlp_default_doc__.format(model_name="GMLPS16", **gated_mlp.BLOCK_CONFIGS["s16"]) +GMLPB16.__doc__ = __gmlp_default_doc__.format(model_name="GMLPB16", **gated_mlp.BLOCK_CONFIGS["b16"]) diff --git a/keras_cv_attention_models/mlp/gated_mlp.py b/keras_cv_attention_models/mlp/gated_mlp.py new file mode 100644 index 00000000..cb84d5af --- /dev/null +++ b/keras_cv_attention_models/mlp/gated_mlp.py @@ -0,0 +1,136 @@ +from tensorflow import keras +from tensorflow.keras import backend as K +import tensorflow as tf +import os + +BATCH_NORM_EPSILON = 1e-5 + + +def layer_norm(inputs, name=None): + norm_axis = -1 if K.image_data_format() == "channels_last" else 1 + return keras.layers.LayerNormalization(axis=norm_axis, epsilon=BATCH_NORM_EPSILON, name=name)(inputs) + + +def res_gated_mlp_block(inputs, channels_mlp_dim, drop_rate=0, activation="gelu", name=None): + nn = layer_norm(inputs, name=name + "pre_ln") + nn = keras.layers.Dense(channels_mlp_dim, name=name + "pre_dense")(nn) + nn = keras.layers.Activation(activation, name=name + "gelu")(nn) + # Drop + + # SpatialGatingUnit + uu, vv = tf.split(nn, 2, axis=-1) + # print(f">>>> {uu.shape = }, {vv.shape = }") + vv = layer_norm(vv, name=name + "vv_ln") + vv = keras.layers.Permute((2, 1), name=name + "permute_1")(vv) + ww_init = keras.initializers.truncated_normal(stddev=1e-6) + vv = keras.layers.Dense(vv.shape[-1], kernel_initializer=ww_init, bias_initializer="ones", name=name + "vv_dense")(vv) + vv = keras.layers.Permute((2, 1), name=name + "permute_2")(vv) + # print(f">>>> {uu.shape = }, {vv.shape = }") + gated_out = keras.layers.Multiply()([uu, vv]) + + nn = keras.layers.Dense(inputs.shape[-1], name=name + "gated_dense")(gated_out) + # Drop + + # Drop path + if drop_rate > 0: + nn = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "drop")(nn) + return keras.layers.Add(name=name + "out")([nn, inputs]) + + +def GMLP( + num_blocks, + patch_size, + stem_width, + channels_mlp_dim, + input_shape=(224, 224, 3), + num_classes=0, + activation="gelu", + sam_rho=0, + dropout=0, + drop_connect_rate=0, + classifier_activation="softmax", + pretrained="imagenet", + model_name="gmlp", + kwargs=None, +): + inputs = keras.Input(input_shape) + nn = keras.layers.Conv2D(stem_width, kernel_size=patch_size, strides=patch_size, padding="valid", name="stem")(inputs) + nn = keras.layers.Reshape([nn.shape[1] * nn.shape[2], stem_width])(nn) + + drop_connect_s, drop_connect_e = drop_connect_rate if isinstance(drop_connect_rate, (list, tuple)) else [drop_connect_rate, drop_connect_rate] + for ii in range(num_blocks): + name = "{}_{}_".format("gmlp", str(ii + 1)) + block_drop_rate = drop_connect_s + (drop_connect_e - drop_connect_s) * ii / num_blocks + nn = res_gated_mlp_block(nn, 
+
+
+def GMLP(
+    num_blocks,
+    patch_size,
+    stem_width,
+    channels_mlp_dim,
+    input_shape=(224, 224, 3),
+    num_classes=0,
+    activation="gelu",
+    sam_rho=0,
+    dropout=0,
+    drop_connect_rate=0,
+    classifier_activation="softmax",
+    pretrained="imagenet",
+    model_name="gmlp",
+    kwargs=None,
+):
+    inputs = keras.Input(input_shape)
+    nn = keras.layers.Conv2D(stem_width, kernel_size=patch_size, strides=patch_size, padding="valid", name="stem")(inputs)
+    nn = keras.layers.Reshape([nn.shape[1] * nn.shape[2], stem_width])(nn)
+
+    drop_connect_s, drop_connect_e = drop_connect_rate if isinstance(drop_connect_rate, (list, tuple)) else [drop_connect_rate, drop_connect_rate]
+    for ii in range(num_blocks):
+        name = "{}_{}_".format("gmlp", str(ii + 1))
+        block_drop_rate = drop_connect_s + (drop_connect_e - drop_connect_s) * ii / num_blocks
+        nn = res_gated_mlp_block(nn, channels_mlp_dim=channels_mlp_dim, drop_rate=block_drop_rate, activation=activation, name=name)
+    nn = layer_norm(nn, name="pre_head_norm")
+
+    if num_classes > 0:
+        # nn = tf.reduce_mean(nn, axis=1)
+        nn = keras.layers.GlobalAveragePooling1D()(nn)
+        if dropout > 0 and dropout < 1:
+            nn = keras.layers.Dropout(dropout)(nn)
+        nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
+
+    if sam_rho != 0:
+        from keras_cv_attention_models.model_surgery import SAMModel
+
+        model = SAMModel(inputs, nn, name=model_name)
+    else:
+        model = keras.Model(inputs, nn, name=model_name)
+    reload_model_weights(model, input_shape, pretrained)
+    return model
+
+
+def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
+    pretrained_dd = {
+        "gmlp_s16": ["imagenet"],
+    }
+    if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
+        print(">>>> No pretrained available, model will be randomly initialized")
+        return
+
+    pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/{}_{}.h5"
+    url = pre_url.format(model.name, pretrained)
+    file_name = os.path.basename(url)
+    try:
+        pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
+    except:
+        print("[Error] will not load weights, url not found or download failed:", url)
+        return
+    else:
+        print(">>>> Load pretrained from:", pretrained_model)
+        model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)
+
+
+BLOCK_CONFIGS = {
+    "tiny16": {
+        "num_blocks": 30,
+        "patch_size": 16,
+        "stem_width": 128,
+        "channels_mlp_dim": 128 * 6,
+    },
+    "s16": {
+        "num_blocks": 30,
+        "patch_size": 16,
+        "stem_width": 256,
+        "channels_mlp_dim": 256 * 6,
+    },
+    "b16": {
+        "num_blocks": 30,
+        "patch_size": 16,
+        "stem_width": 512,
+        "channels_mlp_dim": 512 * 6,
+    },
+}
+
+
+def GMLPTiny16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return GMLP(**BLOCK_CONFIGS["tiny16"], **locals(), model_name="gmlp_tiny16", **kwargs)
+
+
+def GMLPS16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return GMLP(**BLOCK_CONFIGS["s16"], **locals(), model_name="gmlp_s16", **kwargs)
+
+
+def GMLPB16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return GMLP(**BLOCK_CONFIGS["b16"], **locals(), model_name="gmlp_b16", **kwargs)
diff --git a/keras_cv_attention_models/mlp/mlp_mixer.py b/keras_cv_attention_models/mlp/mlp_mixer.py
new file mode 100644
index 00000000..35017847
--- /dev/null
+++ b/keras_cv_attention_models/mlp/mlp_mixer.py
@@ -0,0 +1,206 @@
+from tensorflow import keras
+from tensorflow.keras import backend as K
+import os
+
+BATCH_NORM_EPSILON = 1e-5
+
+
+def layer_norm(inputs, name=None):
+    norm_axis = -1 if K.image_data_format() == "channels_last" else 1
+    return keras.layers.LayerNormalization(axis=norm_axis, epsilon=BATCH_NORM_EPSILON, name=name)(inputs)
+
+
+def mlp_block(inputs, hidden_dim, activation="gelu", name=None):
+    nn = keras.layers.Dense(hidden_dim, name=name + "Dense_0")(inputs)
+    nn = keras.layers.Activation(activation, name=name + "gelu")(nn)
+    nn = keras.layers.Dense(inputs.shape[-1], name=name + "Dense_1")(nn)
+    return nn
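+
+
+# mixer_block below alternates two residual sub-blocks: token mixing applies mlp_block
+# across the patch axis (via Permute), channel mixing applies it across the feature axis.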
"LayerNorm_0") + nn = keras.layers.Permute((2, 1), name=name + "permute_0")(nn) + nn = mlp_block(nn, tokens_mlp_dim, activation, name=name + "token_mixing/") + nn = keras.layers.Permute((2, 1), name=name + "permute_1")(nn) + if drop_rate > 0: + nn = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "token_drop")(nn) + token_out = keras.layers.Add(name=name + "add_0")([nn, inputs]) + + nn = layer_norm(token_out, name=name + "LayerNorm_1") + channel_out = mlp_block(nn, channels_mlp_dim, activation, name=name + "channel_mixing/") + if drop_rate > 0: + channel_out = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "channel_drop")(channel_out) + return keras.layers.Add(name=name + "add_1")([channel_out, token_out]) + + +def MLPMixer( + num_blocks, + patch_size, + stem_width, + tokens_mlp_dim, + channels_mlp_dim, + input_shape=(224, 224, 3), + num_classes=0, + activation="gelu", + sam_rho=0, + dropout=0, + drop_connect_rate=0, + classifier_activation="softmax", + pretrained="imagenet", + model_name="mlp_mixer", + kwargs=None, +): + inputs = keras.Input(input_shape) + nn = keras.layers.Conv2D(stem_width, kernel_size=patch_size, strides=patch_size, padding="same", name="stem")(inputs) + nn = keras.layers.Reshape([nn.shape[1] * nn.shape[2], stem_width])(nn) + + drop_connect_s, drop_connect_e = drop_connect_rate if isinstance(drop_connect_rate, (list, tuple)) else [drop_connect_rate, drop_connect_rate] + for ii in range(num_blocks): + name = "{}_{}/".format("MixerBlock", str(ii)) + block_drop_rate = drop_connect_s + (drop_connect_e - drop_connect_s) * ii / num_blocks + nn = mixer_block(nn, tokens_mlp_dim, channels_mlp_dim, drop_rate=block_drop_rate, activation=activation, name=name) + nn = layer_norm(nn, name="pre_head_layer_norm") + + if num_classes > 0: + nn = keras.layers.GlobalAveragePooling1D()(nn) # tf.reduce_mean(nn, axis=1) + if dropout > 0 and dropout < 1: + nn = keras.layers.Dropout(dropout)(nn) + nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="head")(nn) + + if sam_rho != 0: + from keras_cv_attention_models.model_surgery import SAMModel + + model = SAMModel(inputs, nn, name=model_name) + else: + model = keras.Model(inputs, nn, name=model_name) + reload_model_weights(model, input_shape, pretrained) + return model + + +def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"): + pretrained_dd = { + "mlp_mixer_b16": ["imagenet", "imagenet_sam", "imagenet21k"], + "mlp_mixer_l16": ["imagenet", "imagenet21k"], + "mlp_mixer_b32": ["imagenet_sam"], + } + if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]: + print(">>>> No pretraind available, model will be randomly initialized") + return + + pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/{}_{}.h5" + url = pre_url.format(model.name, pretrained) + file_name = os.path.basename(url) + try: + pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models") + except: + print("[Error] will not load weights, url not found or download failed:", url) + return + else: + print(">>>> Load pretraind from:", pretrained_model) + model.load_weights(pretrained_model, by_name=True, skip_mismatch=True) + + +BLOCK_CONFIGS = { + "s32": { + "num_blocks": 8, + "patch_size": 32, + "stem_width": 512, + "tokens_mlp_dim": 256, + "channels_mlp_dim": 2048, + }, + "s16": { + "num_blocks": 8, + "patch_size": 16, + "stem_width": 512, + "tokens_mlp_dim": 256, + "channels_mlp_dim": 2048, + }, + "b32": { 
+
+
+BLOCK_CONFIGS = {
+    "s32": {
+        "num_blocks": 8,
+        "patch_size": 32,
+        "stem_width": 512,
+        "tokens_mlp_dim": 256,
+        "channels_mlp_dim": 2048,
+    },
+    "s16": {
+        "num_blocks": 8,
+        "patch_size": 16,
+        "stem_width": 512,
+        "tokens_mlp_dim": 256,
+        "channels_mlp_dim": 2048,
+    },
+    "b32": {
+        "num_blocks": 12,
+        "patch_size": 32,
+        "stem_width": 768,
+        "tokens_mlp_dim": 384,
+        "channels_mlp_dim": 3072,
+    },
+    "b16": {
+        "num_blocks": 12,
+        "patch_size": 16,
+        "stem_width": 768,
+        "tokens_mlp_dim": 384,
+        "channels_mlp_dim": 3072,
+    },
+    "l32": {
+        "num_blocks": 24,
+        "patch_size": 32,
+        "stem_width": 1024,
+        "tokens_mlp_dim": 512,
+        "channels_mlp_dim": 4096,
+    },
+    "l16": {
+        "num_blocks": 24,
+        "patch_size": 16,
+        "stem_width": 1024,
+        "tokens_mlp_dim": 512,
+        "channels_mlp_dim": 4096,
+    },
+    "h14": {
+        "num_blocks": 32,
+        "patch_size": 14,
+        "stem_width": 1280,
+        "tokens_mlp_dim": 640,
+        "channels_mlp_dim": 5120,
+    },
+}
+
+
+def MLPMixerS32(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["s32"], **locals(), model_name="mlp_mixer_s32", **kwargs)
+
+
+def MLPMixerS16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["s16"], **locals(), model_name="mlp_mixer_s16", **kwargs)
+
+
+def MLPMixerB32(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["b32"], **locals(), model_name="mlp_mixer_b32", **kwargs)
+
+
+def MLPMixerB16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["b16"], **locals(), model_name="mlp_mixer_b16", **kwargs)
+
+
+def MLPMixerL32(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["l32"], **locals(), model_name="mlp_mixer_l32", **kwargs)
+
+
+def MLPMixerL16(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["l16"], **locals(), model_name="mlp_mixer_l16", **kwargs)
+
+
+def MLPMixerH14(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return MLPMixer(**BLOCK_CONFIGS["h14"], **locals(), model_name="mlp_mixer_h14", **kwargs)
+
+
+if __name__ == "__convert__":
+    import numpy as np
+
+    aa = np.load("../models/imagenet1k_Mixer-B_16.npz")
+    bb = {kk: vv for kk, vv in aa.items()}
+    # cc = {kk: vv.shape for kk, vv in bb.items()}
+
+    import mlp_mixer
+
+    mm = mlp_mixer.MLPMixerB16(num_classes=1000, pretrained=None)
+    # dd = {ii.name: ii.shape for ii in mm.weights}
+
+    target_weights_dict = {"kernel": 0, "bias": 1, "scale": 0, "running_var": 3}
+    for kk, vv in bb.items():
+        split_name = kk.split("/")
+        source_name = "/".join(split_name[:-1])
+        source_weight_type = split_name[-1]
+        target_layer = mm.get_layer(source_name)
+
+        target_weights = target_layer.get_weights()
+        target_weight_pos = target_weights_dict[source_weight_type]
+        print("[{}] source: {}, target: {}".format(kk, vv.shape, target_weights[target_weight_pos].shape))
+
+        target_weights[target_weight_pos] = vv
+        target_layer.set_weights(target_weights)
diff --git a/keras_cv_attention_models/mlp/res_mlp.py b/keras_cv_attention_models/mlp/res_mlp.py
new file mode 100644
index 00000000..96991caf
--- /dev/null
+++ b/keras_cv_attention_models/mlp/res_mlp.py
@@ -0,0 +1,170 @@
+from tensorflow import keras
+import os
+
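+
+# ChannelAffine below implements the per-channel Affine transform from the ResMLP
+# paper, y = ww * x + bb, used in place of LayerNorm.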
+@keras.utils.register_keras_serializable(package="resmlp")
+class ChannelAffine(keras.layers.Layer):
+    def __init__(self, use_bias=True, weight_init_value=1, **kwargs):
+        super(ChannelAffine, self).__init__(**kwargs)
+        self.use_bias, self.weight_init_value = use_bias, weight_init_value
+        self.ww_init = keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else "ones"
+        self.bb_init = "zeros"
+        self.supports_masking = False
+
+    def build(self, input_shape):
+        self.ww = self.add_weight(name="weight", shape=(input_shape[-1],), initializer=self.ww_init, trainable=True)
+        if self.use_bias:
+            self.bb = self.add_weight(name="bias", shape=(input_shape[-1],), initializer=self.bb_init, trainable=True)
+        super(ChannelAffine, self).build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        return inputs * self.ww + self.bb if self.use_bias else inputs * self.ww
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = super(ChannelAffine, self).get_config()
+        config.update({"use_bias": self.use_bias, "weight_init_value": self.weight_init_value})
+        return config
+
+
+# Not used, a functional alternative to ChannelAffine
+def channel_affine(inputs, use_bias=True, weight_init_value=1, name=""):
+    ww_init = keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else "ones"
+    nn = keras.backend.expand_dims(inputs, 1)
+    nn = keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name)(nn)
+    return keras.backend.squeeze(nn, 1)
+
+
+def res_mlp_block(inputs, channels_mlp_dim, drop_rate=0, activation="gelu", name=None):
+    nn = ChannelAffine(use_bias=True, name=name + "norm_1")(inputs)
+    nn = keras.layers.Permute((2, 1), name=name + "permute_1")(nn)
+    nn = keras.layers.Dense(nn.shape[-1], name=name + "token_mixing")(nn)
+    nn = keras.layers.Permute((2, 1), name=name + "permute_2")(nn)
+    nn = ChannelAffine(use_bias=False, name=name + "gamma_1")(nn)
+    if drop_rate > 0:
+        nn = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "token_drop")(nn)
+    token_out = keras.layers.Add(name=name + "add_1")([inputs, nn])
+
+    nn = ChannelAffine(use_bias=True, name=name + "norm_2")(token_out)
+    nn = keras.layers.Dense(channels_mlp_dim, name=name + "channel_mixing_1")(nn)
+    nn = keras.layers.Activation(activation, name=name + activation)(nn)
+    nn = keras.layers.Dense(inputs.shape[-1], name=name + "channel_mixing_2")(nn)
+    channel_out = ChannelAffine(use_bias=False, name=name + "gamma_2")(nn)
+    if drop_rate > 0:
+        channel_out = keras.layers.Dropout(drop_rate, noise_shape=(None, 1, 1), name=name + "channel_drop")(channel_out)
+    nn = keras.layers.Add(name=name + "add_2")([channel_out, token_out])
+    return nn
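+
+
+# res_mlp_block above mirrors the Mixer layout: a cross-patch Dense (`token_mixing`,
+# via Permute) followed by a two-layer channel MLP, each residual branch rescaled
+# by a bias-free ChannelAffine (`gamma_1` / `gamma_2`).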
+
+
+def ResMLP(
+    num_blocks,
+    patch_size,
+    stem_width,
+    channels_mlp_dim,
+    input_shape=(224, 224, 3),
+    num_classes=0,
+    activation="gelu",
+    sam_rho=0,
+    dropout=0,
+    drop_connect_rate=0,
+    classifier_activation="softmax",
+    pretrained="imagenet",
+    model_name="resmlp",
+    kwargs=None,
+):
+    inputs = keras.Input(input_shape)
+    nn = keras.layers.Conv2D(stem_width, kernel_size=patch_size, strides=patch_size, padding="valid", name="stem")(inputs)
+    nn = keras.layers.Reshape([nn.shape[1] * nn.shape[2], stem_width])(nn)
+
+    drop_connect_s, drop_connect_e = drop_connect_rate if isinstance(drop_connect_rate, (list, tuple)) else [drop_connect_rate, drop_connect_rate]
+    for ii in range(num_blocks):
+        name = "{}_{}_".format("ResMlpBlock", str(ii + 1))
+        block_drop_rate = drop_connect_s + (drop_connect_e - drop_connect_s) * ii / num_blocks
+        nn = res_mlp_block(nn, channels_mlp_dim=channels_mlp_dim, drop_rate=block_drop_rate, activation=activation, name=name)
+    nn = ChannelAffine(name="pre_head_norm")(nn)
+
+    if num_classes > 0:
+        # nn = tf.reduce_mean(nn, axis=1)
+        nn = keras.layers.GlobalAveragePooling1D()(nn)
+        if dropout > 0 and dropout < 1:
+            nn = keras.layers.Dropout(dropout)(nn)
+        nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
+
+    if sam_rho != 0:
+        from keras_cv_attention_models.model_surgery import SAMModel
+
+        model = SAMModel(inputs, nn, name=model_name)
+    else:
+        model = keras.Model(inputs, nn, name=model_name)
+    reload_model_weights(model, input_shape, pretrained)
+    return model
+
+
+def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
+    pretrained_dd = {
+        "resmlp12": ["imagenet"],
+        "resmlp24": ["imagenet"],
+        "resmlp36": ["imagenet"],
+        "resmlp_b24": ["imagenet", "imagenet22k"],
+    }
+    if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
+        print(">>>> No pretrained available, model will be randomly initialized")
+        return
+
+    pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/mlp/{}_{}.h5"
+    url = pre_url.format(model.name, pretrained)
+    file_name = os.path.basename(url)
+    try:
+        pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
+    except:
+        print("[Error] will not load weights, url not found or download failed:", url)
+        return
+    else:
+        print(">>>> Load pretrained from:", pretrained_model)
+        model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)
+
+
+BLOCK_CONFIGS = {
+    "12": {
+        "num_blocks": 12,
+        "patch_size": 16,
+        "stem_width": 384,
+        "channels_mlp_dim": 384 * 4,
+    },
+    "24": {
+        "num_blocks": 24,
+        "patch_size": 16,
+        "stem_width": 384,
+        "channels_mlp_dim": 384 * 4,
+    },
+    "36": {
+        "num_blocks": 36,
+        "patch_size": 16,
+        "stem_width": 384,
+        "channels_mlp_dim": 384 * 4,
+    },
+    "b24": {
+        "num_blocks": 24,
+        "patch_size": 8,
+        "stem_width": 768,
+        "channels_mlp_dim": 768 * 4,
+    },
+}
+
+
+def ResMLP12(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return ResMLP(**BLOCK_CONFIGS["12"], **locals(), model_name="resmlp12", **kwargs)
+
+
+def ResMLP24(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return ResMLP(**BLOCK_CONFIGS["24"], **locals(), model_name="resmlp24", **kwargs)
+
+
+def ResMLP36(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return ResMLP(**BLOCK_CONFIGS["36"], **locals(), model_name="resmlp36", **kwargs)
+
+
+def ResMLP_B24(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    return ResMLP(**BLOCK_CONFIGS["b24"], **locals(), model_name="resmlp_b24", **kwargs)
diff --git a/keras_cv_attention_models/version.py b/keras_cv_attention_models/version.py
new file mode 100644
index 00000000..cd7ca498
--- /dev/null
+++ b/keras_cv_attention_models/version.py
@@ -0,0 +1 @@
+__version__ = '1.0.1'
diff --git a/setup.py b/setup.py
index a058c898..f301f89b 100644
--- a/setup.py
+++ b/setup.py
@@ -1,20 +1,49 @@
-from setuptools import find_packages
-from setuptools import setup
+""" Setup
+"""
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+
+exec(open('keras_cv_attention_models/version.py').read())

 setup(
     name="keras-cv-attention-models",
-    version="1.0.0",
+    version=__version__,
+    description="tensorflow keras computer vision attention models",
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    url="https://github.com/leondgarse/keras_cv_attention_models",
     author="Leondgarse",
     author_email="leondgarse@google.com",
-    url="https://github.com/leondgarse/keras_cv_attention_models",
-    description="keras attention models",
-    long_description=open('README.md').read(),
-    long_description_content_type='text/markdown',
-    install_requires=[
-        "einops",
-        "tensorflow",
-        "tensorflow-addons",
+    classifiers=[
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Education',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: Apache Software License',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Topic :: Scientific/Engineering',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Topic :: Software Development',
+        'Topic :: Software Development :: Libraries',
+        'Topic :: Software Development :: Libraries :: Python Modules',
     ],
-    packages=find_packages(),
+
+    # Note that this is a string of words separated by whitespace, not a list.
+    keywords='tensorflow keras cv attention pretrained models',
+    packages=find_packages(exclude=['tests']),
+    include_package_data=True,
+    install_requires=["tensorflow", "tensorflow-addons", "einops"],
+    python_requires='>=3.6',
     license="Apache 2.0",
 )