Skip to content

Commit

Permalink
update resnet_family
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Aug 23, 2021
1 parent 5f5d038 commit c4dcf75
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 3 deletions.
2 changes: 1 addition & 1 deletion keras_cv_attention_models/attention_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from keras_cv_attention_models.coat.coat import ConvPositionalEncoding, ConvRelativePositionalEncoding, layer_norm
from keras_cv_attention_models.halonet.halonet import HaloAttention
from keras_cv_attention_models.resnest.resnest import rsoftmax, split_attention_conv2d
from keras_cv_attention_models.resnext.resnext import groups_depthwise
from keras_cv_attention_models.resnet_family.resnext import groups_depthwise
from keras_cv_attention_models.volo.volo import outlook_attention, outlook_attention_simple, BiasLayer, PositionalEmbedding, ClassToken
from keras_cv_attention_models.mlp_family.mlp_mixer import mlp_block, mixer_block
from keras_cv_attention_models.mlp_family.res_mlp import ChannelAffine
Expand Down
43 changes: 43 additions & 0 deletions keras_cv_attention_models/resnet_family/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ___Keras ResNet Family___
***

## Summary
- Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt). Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
- Model weights reloaded from [Tensorflow keras/applications](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py).
***

## Models
| Model | Params | Image resolution | Top1 Acc | Download |
| --------------------- | ------ | ----------------- | -------- | ------------------- |
| resnext50 (32x4d) | 25M | 224 | 77.74 | [resnext50.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext50.h5) |
| resnext101 (32x4d) | 42M | 224 | 78.73 | [resnext101.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext101.h5) |
## Usage
```py
from keras_cv_attention_models import resnext

# Will download and load pretrained imagenet weights.
mm = resnext.ResNeXt50(pretrained="imagenet")

# Run prediction
from skimage.data import chelsea
imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # Chelsea the cat
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
# [('n02124075', 'Egyptian_cat', 0.98292357),
# ('n02123045', 'tabby', 0.009655442),
# ('n02123159', 'tiger_cat', 0.0057404325),
# ('n02127052', 'lynx', 0.00089362176),
# ('n04209239', 'shower_curtain', 0.00013918217)]
```
**Set new input resolution**
```py
from keras_cv_attention_models import resnext
mm = resnext.ResNeXt101(input_shape=(320, 320, 3), num_classes=0)
print(mm(np.ones([1, 320, 320, 3])).shape)
# (1, 10, 10, 2048)

mm = resnext.ResNeXt101(input_shape=(512, 512, 3), num_classes=0)
print(mm(np.ones([1, 512, 512, 3])).shape)
# (1, 16, 16, 2048)
```
***
88 changes: 88 additions & 0 deletions keras_cv_attention_models/resnet_family/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from keras_cv_attention_models.resnet_family.resnext import ResNeXt, ResNeXt50, ResNeXt101, groups_depthwise


__head_doc__ = """
Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt).
Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
"""

__tail_doc__ = """ strides: a `number` or `list`, indicates strides used in the last stack or list value for all stacks.
If a number, it will be `[1, 2, 2, strides]`.
out_channels: default as `[128, 256, 512, 1024]`. Output channel for each stack.
stem_width: output dimension for stem block.
deep_stem: Boolean value if use deep stem.
stem_downsample: Boolean value if ass `MaxPooling2D` layer after stem block.
cardinality: Control channel expansion in each block, the bigger the widder.
Also the `groups` number for `groups_depthwise` in each block, bigger `cardinality` leads to less `groups`.
input_shape: it should have exactly 3 inputs channels, default `(224, 224, 3)`.
num_classes: number of classes to classify images into. Set `0` to exclude top layers.
activation: activation used in whole model, default `relu`.
classifier_activation: A `str` or callable. The activation function to use on the "top" layer if `num_classes > 0`.
Set `classifier_activation=None` to return the logits of the "top" layer.
Default is `softmax`.
pretrained: one of `None` (random initialization) or 'imagenet' (pre-training on ImageNet).
Will try to download and load pre-trained model weights if not None.
**kwargs: other parameters if available.
Returns:
A `keras.Model` instance.
"""

ResNeXt.__doc__ = __head_doc__ + """
Args:
num_blocks: number of blocks in each stack.
model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
| Model | Params | Image resolution | Top1 Acc |
| -------------- | ------ | ----------------- | -------- |
| resnext50 | 25M | 224 | 77.8 |
| resnext101 | 42M | 224 | 80.9 |
"""

ResNeXt50.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__

ResNeXt101.__doc__ = ResNeXt50.__doc__

groups_depthwise.__doc__ = __head_doc__ + """
Grouped depthwise. Callable function, NOT defined as a layer.
Args:
inputs: input tensor.
groups: number of groups splitted for `DepthwiseConv2D` result.
kernel_size: kernel size for `DepthwiseConv2D`.
strides: strides for `DepthwiseConv2D`.
padding: padding for `DepthwiseConv2D`.
Examples:
>>> from keras_cv_attention_models import attention_layers
>>> inputs = keras.layers.Input([28, 28, 192])
>>> nn = attention_layers.groups_depthwise(inputs, groups=32)
>>> dd = keras.models.Model(inputs, nn)
>>> dd.output_shape
(None, 28, 28, 192)
>>> dd.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 28, 28, 192)] 0
_________________________________________________________________
zero_padding2d (ZeroPadding2 (None, 30, 30, 192) 0
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 28, 28, 1152) 10368
_________________________________________________________________
reshape (Reshape) (None, 28, 28, 32, 6, 6) 0
_________________________________________________________________
tf.math.reduce_sum (TFOpLamb (None, 28, 28, 32, 6) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 28, 28, 192) 0
=================================================================
Total params: 10,368
Trainable params: 10,368
Non-trainable params: 0
_________________________________________________________________
"""
49 changes: 49 additions & 0 deletions keras_cv_attention_models/resnet_family/resnet_deep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from keras_cv_attention_models.aotnet import AotNet
import os

def ResNetD(num_blocks, input_shape=(224, 224, 3), pretrained="imagenet", deep_stem=True, stem_width=32, strides=2, **kwargs):
strides = strides if isinstance(strides, (list, tuple)) else [1, 2, 2, strides]
model = AotNet(num_blocks, input_shape=input_shape, deep_stem=deep_stem, stem_width=stem_width, strides=strides, **kwargs)
reload_model_weights(model, input_shape, pretrained)
return model


def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
pretrained_dd = {
"resnet50d": ["imagenet"],
}
if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
print(">>>> No pretraind available, model will be randomly initialized")
return

pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnet_family/{}_{}.h5"
url = pre_url.format(model.name, pretrained)
file_name = os.path.basename(url)
try:
pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
except:
print("[Error] will not load weights, url not found or download failed:", url)
return
else:
print(">>>> Load pretraind from:", pretrained_model)
model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)


def ResNet50D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
num_blocks = [3, 4, 6, 3]
return ResNetD(**locals(), model_name="resnet50d", **kwargs)


def ResNet101D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
num_blocks = [3, 4, 23, 3]
return ResNetD(**locals(), model_name="resnet101d", **kwargs)


def ResNet152D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
num_blocks = [3, 8, 36, 3]
return ResNetD(**locals(), model_name="resnet152d", **kwargs)


def ResNet200D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
num_blocks = [3, 24, 36, 3]
return ResNetD(**locals(), model_name="resnet200d", **kwargs)
Loading

0 comments on commit c4dcf75

Please sign in to comment.