diff --git a/keras_cv_attention_models/attention_layers/__init__.py b/keras_cv_attention_models/attention_layers/__init__.py
index 2bd62ed7..ef040563 100644
--- a/keras_cv_attention_models/attention_layers/__init__.py
+++ b/keras_cv_attention_models/attention_layers/__init__.py
@@ -4,7 +4,7 @@
 from keras_cv_attention_models.coat.coat import ConvPositionalEncoding, ConvRelativePositionalEncoding, layer_norm
 from keras_cv_attention_models.halonet.halonet import HaloAttention
 from keras_cv_attention_models.resnest.resnest import rsoftmax, split_attention_conv2d
-from keras_cv_attention_models.resnext.resnext import groups_depthwise
+from keras_cv_attention_models.resnet_family.resnext import groups_depthwise
 from keras_cv_attention_models.volo.volo import outlook_attention, outlook_attention_simple, BiasLayer, PositionalEmbedding, ClassToken
 from keras_cv_attention_models.mlp_family.mlp_mixer import mlp_block, mixer_block
 from keras_cv_attention_models.mlp_family.res_mlp import ChannelAffine
diff --git a/keras_cv_attention_models/resnet_family/README.md b/keras_cv_attention_models/resnet_family/README.md
new file mode 100644
index 00000000..0e2183a0
--- /dev/null
+++ b/keras_cv_attention_models/resnet_family/README.md
@@ -0,0 +1,43 @@
+# ___Keras ResNet Family___
+***
+
+## Summary
+  - Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt). Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
+  - Model weights reloaded from [Tensorflow keras/applications](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py).
+***
+
+## Models
+  | Model              | Params | Image resolution | Top1 Acc | Download |
+  | ------------------ | ------ | ---------------- | -------- | -------- |
+  | resnext50 (32x4d)  | 25M    | 224              | 77.74    | [resnext50.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext50.h5) |
+  | resnext101 (32x4d) | 42M    | 224              | 78.73    | [resnext101.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnext/resnext101.h5) |
+## Usage
+  ```py
+  from keras_cv_attention_models import resnext
+
+  # Will download and load pretrained imagenet weights.
+  mm = resnext.ResNeXt50(pretrained="imagenet")
+
+  # Run prediction
+  from skimage.data import chelsea
+  imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # Chelsea the cat
+  pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
+  print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
+  # [('n02124075', 'Egyptian_cat', 0.98292357),
+  #  ('n02123045', 'tabby', 0.009655442),
+  #  ('n02123159', 'tiger_cat', 0.0057404325),
+  #  ('n02127052', 'lynx', 0.00089362176),
+  #  ('n04209239', 'shower_curtain', 0.00013918217)]
+  ```
+  **Set new input resolution**
+  ```py
+  from keras_cv_attention_models import resnext
+  mm = resnext.ResNeXt101(input_shape=(320, 320, 3), num_classes=0)
+  print(mm(np.ones([1, 320, 320, 3])).shape)
+  # (1, 10, 10, 2048)
+
+  mm = resnext.ResNeXt101(input_shape=(512, 512, 3), num_classes=0)
+  print(mm(np.ones([1, 512, 512, 3])).shape)
+  # (1, 16, 16, 2048)
+  ```
+***
diff --git a/keras_cv_attention_models/resnet_family/__init__.py b/keras_cv_attention_models/resnet_family/__init__.py
new file mode 100644
index 00000000..76644a49
--- /dev/null
+++ b/keras_cv_attention_models/resnet_family/__init__.py
@@ -0,0 +1,88 @@
+from keras_cv_attention_models.resnet_family.resnext import ResNeXt, ResNeXt50, ResNeXt101, groups_depthwise
+
+
+__head_doc__ = """
+Keras implementation of [Github facebookresearch/ResNeXt](https://github.com/facebookresearch/ResNeXt).
+Paper [PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf).
+"""
+
+__tail_doc__ = """  strides: a `number` or `list`. A list value sets the stride for each stack; a single number sets the stride
+      of the last stack only, expanded as `[1, 2, 2, strides]`.
+  out_channels: default `[128, 256, 512, 1024]`. Output channels for each stack.
+  stem_width: output dimension for the stem block.
+  deep_stem: Boolean, whether to use a deep stem.
+  stem_downsample: Boolean, whether to add a `MaxPooling2D` layer after the stem block.
+  cardinality: controls channel expansion in each block; the bigger, the wider.
+      Also sets the `groups` number for `groups_depthwise` in each block: a bigger `cardinality` leads to fewer `groups`.
+  input_shape: it should have exactly 3 input channels, default `(224, 224, 3)`.
+  num_classes: number of classes to classify images into. Set `0` to exclude top layers.
+  activation: activation used in the whole model, default `relu`.
+  classifier_activation: A `str` or callable. The activation function to use on the "top" layer if `num_classes > 0`.
+      Set `classifier_activation=None` to return the logits of the "top" layer.
+      Default is `softmax`.
+  pretrained: one of `None` (random initialization) or 'imagenet' (pre-training on ImageNet).
+      Will try to download and load pre-trained model weights if not None.
+  **kwargs: other parameters if available.
+
+Returns:
+    A `keras.Model` instance.
+"""
+
+ResNeXt.__doc__ = __head_doc__ + """
+Args:
+  num_blocks: number of blocks in each stack.
+  model_name: string, model name.
+""" + __tail_doc__ + """
+Model architectures:
+  | Model      | Params | Image resolution | Top1 Acc |
+  | ---------- | ------ | ---------------- | -------- |
+  | resnext50  | 25M    | 224              | 77.8     |
+  | resnext101 | 42M    | 224              | 80.9     |
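+
+Examples (a minimal usage sketch; the shape follows the `num_classes=0` behavior documented above):
+
+>>> from keras_cv_attention_models import resnext
+>>> mm = resnext.ResNeXt50(num_classes=0, pretrained=None)
+>>> mm.output_shape
+(None, 7, 7, 2048)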
+"""
+
+ResNeXt50.__doc__ = __head_doc__ + """
+Args:
+""" + __tail_doc__
+
+ResNeXt101.__doc__ = ResNeXt50.__doc__
+
+groups_depthwise.__doc__ = __head_doc__ + """
+Grouped depthwise. Callable function, NOT defined as a layer.
+
+Args:
+  inputs: input tensor.
+  groups: number of groups the `DepthwiseConv2D` output is split into.
+  kernel_size: kernel size for `DepthwiseConv2D`.
+  strides: strides for `DepthwiseConv2D`.
+  padding: padding for `DepthwiseConv2D`.
+
+Examples:
+
+>>> from keras_cv_attention_models import attention_layers
+>>> inputs = keras.layers.Input([28, 28, 192])
+>>> nn = attention_layers.groups_depthwise(inputs, groups=32)
+>>> dd = keras.models.Model(inputs, nn)
+>>> dd.output_shape
+(None, 28, 28, 192)
+
+>>> dd.summary()
+_________________________________________________________________
+Layer (type)                 Output Shape              Param #
+=================================================================
+input_2 (InputLayer)         [(None, 28, 28, 192)]     0
+_________________________________________________________________
+zero_padding2d (ZeroPadding2 (None, 30, 30, 192)       0
+_________________________________________________________________
+depthwise_conv2d (DepthwiseC (None, 28, 28, 1152)      10368
+_________________________________________________________________
+reshape (Reshape)            (None, 28, 28, 32, 6, 6)  0
+_________________________________________________________________
+tf.math.reduce_sum (TFOpLamb (None, 28, 28, 32, 6)     0
+_________________________________________________________________
+reshape_1 (Reshape)          (None, 28, 28, 192)       0
+=================================================================
+Total params: 10,368
+Trainable params: 10,368
+Non-trainable params: 0
+_________________________________________________________________
+"""
diff --git a/keras_cv_attention_models/resnet_family/resnet_deep.py b/keras_cv_attention_models/resnet_family/resnet_deep.py
new file mode 100644
index 00000000..abe93afc
--- /dev/null
+++ b/keras_cv_attention_models/resnet_family/resnet_deep.py
@@ -0,0 +1,49 @@
+from tensorflow import keras
+from keras_cv_attention_models.aotnet import AotNet
+import os
+
+
+def ResNetD(num_blocks, input_shape=(224, 224, 3), pretrained="imagenet", deep_stem=True, stem_width=32, strides=2, **kwargs):
+    strides = strides if isinstance(strides, (list, tuple)) else [1, 2, 2, strides]
+    model = AotNet(num_blocks, input_shape=input_shape, deep_stem=deep_stem, stem_width=stem_width, strides=strides, **kwargs)
+    reload_model_weights(model, input_shape, pretrained)
+    return model
+
+
+def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
+    pretrained_dd = {
+        "resnet50d": ["imagenet"],
+    }
+    if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
+        print(">>>> No pretrained weights available, model will be randomly initialized")
+        return
+
+    pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnet_family/{}_{}.h5"
+    url = pre_url.format(model.name, pretrained)
+    file_name = os.path.basename(url)
+    try:
+        pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
+    except Exception:
+        print("[Error] will not load weights, url not found or download failed:", url)
+        return
+    else:
+        print(">>>> Load pretrained from:", pretrained_model)
+        model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)
+
+
+def ResNet50D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [3, 4, 6, 3]
+    return ResNetD(**locals(), model_name="resnet50d", **kwargs)
+
+
+def ResNet101D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [3, 4, 23, 3]
+    return ResNetD(**locals(), model_name="resnet101d", **kwargs)
+
+
+def ResNet152D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [3, 8, 36, 3]
+    return ResNetD(**locals(), model_name="resnet152d", **kwargs)
+
+
+def ResNet200D(input_shape=(224, 224, 3), num_classes=1000, activation="relu", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [3, 24, 36, 3]
+    return ResNetD(**locals(), model_name="resnet200d", **kwargs)
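+
+
+# A minimal usage sketch (hypothetical quick check, not part of the released file):
+# >>> mm = ResNet50D(pretrained=None)
+# >>> mm.output_shape
+# (None, 1000)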
diff --git a/keras_cv_attention_models/resnet_family/resnet_quad.py b/keras_cv_attention_models/resnet_family/resnet_quad.py
new file mode 100644
index 00000000..3cc7e406
--- /dev/null
+++ b/keras_cv_attention_models/resnet_family/resnet_quad.py
@@ -0,0 +1,218 @@
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import backend as K
+import os
+
+BATCH_NORM_DECAY = 0.9
+BATCH_NORM_EPSILON = 1e-5
+CONV_KERNEL_INITIALIZER = tf.keras.initializers.VarianceScaling(scale=2.0, mode="fan_out", distribution="truncated_normal")
+
+
+def batchnorm_with_activation(inputs, activation="relu", zero_gamma=False, name=None):
+    """Performs a batch normalization followed by an activation."""
+    bn_axis = 3 if K.image_data_format() == "channels_last" else 1
+    gamma_initializer = tf.zeros_initializer() if zero_gamma else tf.ones_initializer()
+    nn = keras.layers.BatchNormalization(
+        axis=bn_axis,
+        momentum=BATCH_NORM_DECAY,
+        epsilon=BATCH_NORM_EPSILON,
+        gamma_initializer=gamma_initializer,
+        name=name and name + "bn",
+    )(inputs)
+    if activation:
+        nn = keras.layers.Activation(activation=activation, name=name and name + activation)(nn)
+    return nn
+
+
+def conv2d_no_bias(inputs, filters, kernel_size, strides=1, padding="VALID", use_bias=False, groups=1, name=None, **kwargs):
+    pad = max(kernel_size) // 2 if isinstance(kernel_size, (list, tuple)) else kernel_size // 2
+    if padding.upper() == "SAME" and pad != 0:
+        inputs = keras.layers.ZeroPadding2D(padding=pad, name=name and name + "pad")(inputs)
+
+    groups = groups if groups != 0 else 1
+    if groups == filters:  # one filter per group, fall back to a true depthwise convolution
+        return keras.layers.DepthwiseConv2D(
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs,
+        )(inputs)
+    else:
+        return keras.layers.Conv2D(
+            filters,
+            kernel_size,
+            strides=strides,
+            padding="VALID",
+            use_bias=use_bias,
+            groups=groups,
+            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            name=name and name + "conv",
+            **kwargs,
+        )(inputs)
+
+
+def drop_block(inputs, drop_rate=0, name=None):
+    if drop_rate > 0:
+        noise_shape = [None] + [1] * (len(inputs.shape) - 1)  # [None, 1, 1, 1]
+        return keras.layers.Dropout(drop_rate, noise_shape=noise_shape, name=name and name + "drop")(inputs)
+    else:
+        return inputs
+
+
+def quad_block(inputs, filters, groups_div=32, strides=1, conv_shortcut=False, expansion=4, extra_conv=False, drop_rate=0, activation="swish", name=""):
+    expanded_filter = filters * expansion
+    groups = filters // groups_div if groups_div != 0 else 1
+    if conv_shortcut:
+        shortcut = conv2d_no_bias(inputs, expanded_filter, 1, strides=strides, name=name + "shortcut_")
+        shortcut = batchnorm_with_activation(shortcut, activation=None, zero_gamma=False, name=name + "shortcut_")
+    else:
+        shortcut = inputs
+
+    if groups != 1:  # Edge block
+        nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "1_")
+        nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "1_")
+    else:
+        nn = inputs
+
+    nn = conv2d_no_bias(nn, filters, 3, strides=strides, padding="SAME", groups=groups, name=name + "groups_")
+    nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "2_")
+
+    if extra_conv:
+        nn = conv2d_no_bias(nn, filters, 3, strides=1, padding="SAME", groups=groups, name=name + "extra_groups_")
+        nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "extra_2_")
+
+    nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "3_")
+    nn = batchnorm_with_activation(nn, activation=None, zero_gamma=True, name=name + "3_")
+
+    # print(">>>> shortcut:", shortcut.shape, "nn:", nn.shape)
+    nn = drop_block(nn, drop_rate)
+    nn = keras.layers.Add(name=name + "add")([shortcut, nn])
+    return keras.layers.Activation(activation, name=name + "out")(nn)
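+
+
+# A minimal shape check for `quad_block` (hypothetical, not part of the released file):
+# >>> inp = keras.layers.Input([56, 56, 64])
+# >>> quad_block(inp, filters=64, strides=2, conv_shortcut=True).shape
+# TensorShape([None, 28, 28, 256])  # expanded_filter = filters * expansion = 64 * 4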
+
+
+def quad_stack(inputs, blocks, filters, groups_div, strides=2, expansion=4, extra_conv=False, stack_drop=0, activation="swish", name=""):
+    nn = inputs
+    stack_drop_s, stack_drop_e = stack_drop if isinstance(stack_drop, (list, tuple)) else [stack_drop, stack_drop]
+    for id in range(blocks):
+        # Only the first block in a stack may downsample or need a conv shortcut for the channel change
+        conv_shortcut = True if id == 0 and (strides != 1 or inputs.shape[-1] != filters * expansion) else False
+        cur_strides = strides if id == 0 else 1
+        block_name = name + "block{}_".format(id + 1)
+        block_drop_rate = stack_drop_s + (stack_drop_e - stack_drop_s) * id / blocks
+        nn = quad_block(nn, filters, groups_div, cur_strides, conv_shortcut, expansion, extra_conv, block_drop_rate, activation, name=block_name)
+    return nn
+
+
+def quad_stem(inputs, stem_width, activation="swish", stem_act=False, name=""):
+    nn = conv2d_no_bias(inputs, stem_width // 8, 3, strides=2, padding="same", name=name + "1_")
+    if stem_act:
+        nn = batchnorm_with_activation(nn, activation=activation, name=name + "1_")
+    nn = conv2d_no_bias(nn, stem_width // 4, 3, strides=1, padding="same", name=name + "2_")
+    if stem_act:
+        nn = batchnorm_with_activation(nn, activation=activation, name=name + "2_")
+    nn = conv2d_no_bias(nn, stem_width // 2, 3, strides=1, padding="same", name=name + "3_")
+    nn = batchnorm_with_activation(nn, activation=activation, name=name + "3_")
+    nn = conv2d_no_bias(nn, stem_width, 3, strides=2, padding="same", name=name + "4_")
+    return nn
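+
+
+# Note: `quad_stem` downsamples by 4 overall (the first and last convs use stride 2),
+# so a 224x224 input reaches the first stack at 56x56 with `stem_width` channels.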
+
+
+def ResNetQ(
+    num_blocks,
+    strides=2,
+    out_channels=[64, 128, 384, 384],
+    stem_width=128,
+    stem_act=False,
+    stem_downsample=False,
+    expansion=4,
+    groups_div=32,
+    extra_conv=False,
+    num_features=2048,
+    input_shape=(224, 224, 3),
+    num_classes=1000,
+    activation="swish",
+    drop_connect_rate=0,
+    classifier_activation="softmax",
+    pretrained="imagenet",
+    model_name="resnetq",
+    kwargs=None,  # absorbs the "kwargs" entry when called with **locals()
+):
+    inputs = keras.layers.Input(shape=input_shape)
+    nn = quad_stem(inputs, stem_width, activation=activation, stem_act=stem_act, name="stem_")
+    nn = batchnorm_with_activation(nn, activation=activation, name="stem_")
+    if stem_downsample:
+        nn = keras.layers.ZeroPadding2D(padding=1, name="stem_pool_pad")(nn)
+        nn = keras.layers.MaxPooling2D(pool_size=3, strides=2, name="stem_pool")(nn)
+
+    total_blocks = sum(num_blocks)
+    global_block_id = 0
+    drop_connect_s, drop_connect_e = 0, drop_connect_rate
+    strides = strides if isinstance(strides, (list, tuple)) else [1, 2, 2, strides]
+    for id, (num_block, out_channel, stride) in enumerate(zip(num_blocks, out_channels, strides)):
+        name = "stack{}_".format(id + 1)
+        stack_drop_s = drop_connect_rate * global_block_id / total_blocks
+        stack_drop_e = drop_connect_rate * (global_block_id + num_block) / total_blocks
+        stack_drop = (stack_drop_s, stack_drop_e)
+        cur_expansion = expansion[id] if isinstance(expansion, (list, tuple)) else expansion
+        cur_extra_conv = extra_conv[id] if isinstance(extra_conv, (list, tuple)) else extra_conv
+        cur_groups_div = groups_div[id] if isinstance(groups_div, (list, tuple)) else groups_div
+        nn = quad_stack(nn, num_block, out_channel, cur_groups_div, stride, cur_expansion, cur_extra_conv, stack_drop, activation, name=name)
+        global_block_id += num_block
+
+    if num_features != 0:  # efficientnet like
+        nn = conv2d_no_bias(nn, num_features, 1, strides=1, name="features_")
+        nn = batchnorm_with_activation(nn, activation=activation, name="features_")
+
+    if num_classes > 0:
+        nn = keras.layers.GlobalAveragePooling2D(name="avg_pool")(nn)
+        nn = keras.layers.Dense(num_classes, activation=classifier_activation, name="predictions")(nn)
+
+    model = keras.models.Model(inputs, nn, name=model_name)
+    reload_model_weights(model, input_shape, pretrained)
+    return model
+
+
+def reload_model_weights(model, input_shape=(224, 224, 3), pretrained="imagenet"):
+    pretrained_dd = {
+        "resnetq51": ["imagenet"],
+    }
+    if model.name not in pretrained_dd or pretrained not in pretrained_dd[model.name]:
+        print(">>>> No pretrained weights available, model will be randomly initialized")
+        return
+
+    pre_url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/resnet_family/{}_{}.h5"
+    url = pre_url.format(model.name, pretrained)
+    file_name = os.path.basename(url)
+    try:
+        pretrained_model = keras.utils.get_file(file_name, url, cache_subdir="models")
+    except Exception:
+        print("[Error] will not load weights, url not found or download failed:", url)
+        return
+    else:
+        print(">>>> Load pretrained from:", pretrained_model)
+        model.load_weights(pretrained_model, by_name=True, skip_mismatch=True)
+
+
+def ResNet51Q(input_shape=(224, 224, 3), num_classes=1000, activation="swish", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [2, 4, 6, 4]
+    out_channels = [64, 128, 384, 384 * 4]
+    stem_width = 128
+    stem_act = False
+    expansion = [4, 4, 4, 1]
+    groups_div = [32, 32, 32, 1]
+    extra_conv = False
+    num_features = 2048
+    return ResNetQ(**locals(), model_name="resnetq51", **kwargs)
+
+
+def ResNet61Q(input_shape=(224, 224, 3), num_classes=1000, activation="swish", classifier_activation="softmax", pretrained="imagenet", **kwargs):
+    num_blocks = [1, 4, 6, 4]
+    out_channels = [256, 128, 384, 384 * 4]
+    stem_width = 128
+    stem_act = True
+    expansion = [1, 4, 4, 1]
+    groups_div = [0, 32, 32, 1]
+    extra_conv = [False, True, True, True]
+    num_features = 2048
+    return ResNetQ(**locals(), model_name="resnetq61", **kwargs)
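+
+
+# A minimal usage sketch (hypothetical quick check, not part of the released file):
+# >>> mm = ResNet51Q(pretrained=None)
+# >>> print(mm.name, mm.output_shape)
+# resnetq51 (None, 1000)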
diff --git a/keras_cv_attention_models/resnext/resnext.py b/keras_cv_attention_models/resnet_family/resnext.py
similarity index 96%
rename from keras_cv_attention_models/resnext/resnext.py
rename to keras_cv_attention_models/resnet_family/resnext.py
index 626e796d..71a4d592 100644
--- a/keras_cv_attention_models/resnext/resnext.py
+++ b/keras_cv_attention_models/resnet_family/resnext.py
@@ -48,7 +48,7 @@ def groups_depthwise(inputs, groups=32, kernel_size=3, strides=1, padding="SAME"
     nn = keras.layers.ZeroPadding2D(padding=kernel_size // 2, name=name and name + "pad")(nn)
     nn = keras.layers.DepthwiseConv2D(kernel_size, strides=strides, depth_multiplier=cc, use_bias=False, name=name and name + "DC")(nn)
     nn = keras.layers.Reshape((*nn.shape[1:-1], groups, cc, cc))(nn)
-    nn = tf.reduce_sum(nn, axis=-2)
+    nn = tf.reduce_sum(nn, axis=-1)
     nn = keras.layers.Reshape((*nn.shape[1:-2], input_filter))(nn)
     return nn
@@ -64,7 +64,8 @@ def block(inputs, filters, strides=1, conv_shortcut=False, cardinality=2, activa
     nn = conv2d_no_bias(inputs, filters, 1, strides=1, padding="VALID", name=name + "1_")
     nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name + "1_")
-    nn = groups_depthwise(nn, groups=64 // cardinality, kernel_size=3, strides=strides, name=name + "GD_")
+    # nn = groups_depthwise(nn, groups=64 // cardinality, kernel_size=3, strides=strides, name=name + "GD_")
+    nn = conv2d_no_bias(nn, nn.shape[-1], 3, strides=strides, groups=64 // cardinality, name=name + "GC_")
     nn = batchnorm_with_activation(nn, activation=activation, zero_gamma=False, name=name)
     nn = conv2d_no_bias(nn, expanded_filter, 1, strides=1, padding="VALID", name=name + "3_")
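
For reference, the grouped `Conv2D` introduced in the last hunk can be sanity-checked in isolation. A minimal sketch, assuming TF >= 2.3 for the `groups` argument of `Conv2D` (shapes here are illustrative, matching the `groups_depthwise` docstring example above):

```py
import tensorflow as tf
from tensorflow import keras

# Grouped 3x3 convolution: 192 channels split into 32 groups of 6,
# same output shape as the DepthwiseConv2D + Reshape + reduce_sum trick.
inputs = keras.layers.Input([28, 28, 192])
nn = keras.layers.Conv2D(192, 3, strides=1, padding="same", groups=32, use_bias=False)(inputs)
print(keras.models.Model(inputs, nn).output_shape)
# (None, 28, 28, 192)
```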