Skip to content

Commit

Permalink
add support for yolo2_efficientnet
Browse files Browse the repository at this point in the history
  • Loading branch information
david8862 committed Dec 9, 2019
1 parent 137c60a commit f4db716
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 2 deletions.
4 changes: 2 additions & 2 deletions common/backbones/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,8 @@ def preprocess_input(x):

if __name__ == '__main__':
input_tensor = Input(shape=(None, None, 3), name='image_input')
model = EfficientNetB1(include_top=False, input_shape=(416, 416, 3), weights='imagenet')
#model = EfficientNetB0(include_top=True, input_tensor=input_tensor, weights='imagenet')
#model = EfficientNetB0(include_top=False, input_shape=(416, 416, 3), weights='imagenet')
model = EfficientNetB0(include_top=True, input_tensor=input_tensor, weights='imagenet')
model.summary()

import numpy as np
Expand Down
12 changes: 12 additions & 0 deletions yolo2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from yolo2.models.yolo2_mobilenet import yolo2_mobilenet_body, yolo2lite_mobilenet_body, tiny_yolo2_mobilenet_body, tiny_yolo2lite_mobilenet_body
from yolo2.models.yolo2_mobilenetv2 import yolo2_mobilenetv2_body, yolo2lite_mobilenetv2_body, tiny_yolo2_mobilenetv2_body, tiny_yolo2lite_mobilenetv2_body
from yolo2.models.yolo2_xception import yolo2_xception_body, yolo2lite_xception_body
from yolo2.models.yolo2_efficientnet import yolo2_efficientnet_body, yolo2lite_efficientnet_body, tiny_yolo2_efficientnet_body, tiny_yolo2lite_efficientnet_body
from yolo2.loss import yolo2_loss
from yolo2.postprocess import batched_yolo2_postprocess

Expand All @@ -30,6 +31,12 @@
'yolo2_mobilenet_lite': [yolo2lite_mobilenet_body, 87, None],
'yolo2_mobilenetv2': [yolo2_mobilenetv2_body, 155, None],
'yolo2_mobilenetv2_lite': [yolo2lite_mobilenetv2_body, 155, None],

# NOTE: backbone_length is for EfficientNetB0
# if change to other efficientnet level, you need to modify it
'yolo2_efficientnet': [yolo2_efficientnet_body, 235, None],
'yolo2_efficientnet_lite': [yolo2lite_efficientnet_body, 235, None],

'yolo2_xception': [yolo2_xception_body, 132, None],
'yolo2_xception_lite': [yolo2lite_xception_body, 132, None],

Expand All @@ -38,6 +45,11 @@
'tiny_yolo2_mobilenet_lite': [tiny_yolo2lite_mobilenet_body, 87, None],
'tiny_yolo2_mobilenetv2': [tiny_yolo2_mobilenetv2_body, 155, None],
'tiny_yolo2_mobilenetv2_lite': [tiny_yolo2lite_mobilenetv2_body, 155, None],

# NOTE: backbone_length is for EfficientNetB0
# if change to other efficientnet level, you need to modify it
'tiny_yolo2_efficientnet': [tiny_yolo2_efficientnet_body, 235, None],
'tiny_yolo2_efficientnet_lite': [tiny_yolo2lite_efficientnet_body, 235, None],
}


Expand Down
249 changes: 249 additions & 0 deletions yolo2/models/yolo2_efficientnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""YOLO_v2 EfficientNet Model Defined in Keras."""

from tensorflow.keras.layers import MaxPooling2D, Lambda, Concatenate, GlobalAveragePooling2D, Softmax
from tensorflow.keras.models import Model

from common.backbones.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from yolo2.models.layers import compose, DarknetConv2D, DarknetConv2D_BN_Leaky, Depthwise_Separable_Conv2D_BN_Leaky, bottleneck_block, bottleneck_x2_block, space_to_depth_x2, space_to_depth_x2_output_shape


def get_efficientnet_backbone_info(input_tensor, level=0):
"""Parse different level EfficientNet backbone feature map info for YOLOv2 head build."""
if level == 0:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 1280
# block6a_expand_activation(middle in block6a): 26 x 26 x 672
# block5c_add(end of block5c): 26 x 26 x 112
# block4a_expand_activation(middle in block4a): 52 x 52 x 240
# block3b_add(end of block3b): 52 x 52 x 40
efficientnet = EfficientNetB0(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 1280
f2_name = 'block6a_expand_activation'
f2_channel_num = 672
f3_name = 'block4a_expand_activation'
f3_channel_num = 240

elif level == 1:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 1280
# block6a_expand_activation(middle in block6a): 26 x 26 x 672
# block5d_add(end of block5d): 26 x 26 x 112
# block4a_expand_activation(middle in block4a): 52 x 52 x 240
# block3c_add(end of block3c): 52 x 52 x 40
efficientnet = EfficientNetB1(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 1280
f2_name = 'block6a_expand_activation'
f2_channel_num = 672
f3_name = 'block4a_expand_activation'
f3_channel_num = 240

elif level == 2:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 1408
# block6a_expand_activation(middle in block6a): 26 x 26 x 720
# block5d_add(end of block5d): 26 x 26 x 120
# block4a_expand_activation(middle in block4a): 52 x 52 x 288
# block3c_add(end of block3c): 52 x 52 x 48
efficientnet = EfficientNetB2(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 1408
f2_name = 'block6a_expand_activation'
f2_channel_num = 720
f3_name = 'block4a_expand_activation'
f3_channel_num = 288

elif level == 3:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 1536
# block6a_expand_activation(middle in block6a): 26 x 26 x 816
# block5e_add(end of block5e): 26 x 26 x 136
# block4a_expand_activation(middle in block4a): 52 x 52 x 288
# block3c_add(end of block3c): 52 x 52 x 48
efficientnet = EfficientNetB3(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 1536
f2_name = 'block6a_expand_activation'
f2_channel_num = 816
f3_name = 'block4a_expand_activation'
f3_channel_num = 288

elif level == 4:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 1792
# block6a_expand_activation(middle in block6a): 26 x 26 x 960
# block5f_add(end of block5f): 26 x 26 x 160
# block4a_expand_activation(middle in block4a): 52 x 52 x 336
# block3d_add(end of block3d): 52 x 52 x 56
efficientnet = EfficientNetB4(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 1792
f2_name = 'block6a_expand_activation'
f2_channel_num = 960
f3_name = 'block4a_expand_activation'
f3_channel_num = 336

elif level == 5:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 2048
# block6a_expand_activation(middle in block6a): 26 x 26 x 1056
# block5g_add(end of block5g): 26 x 26 x 176
# block4a_expand_activation(middle in block4a): 52 x 52 x 384
# block3e_add(end of block3e): 52 x 52 x 64
efficientnet = EfficientNetB5(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 2048
f2_name = 'block6a_expand_activation'
f2_channel_num = 1056
f3_name = 'block4a_expand_activation'
f3_channel_num = 384

elif level == 6:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 2304
# block6a_expand_activation(middle in block6a): 26 x 26 x 1200
# block5h_add(end of block5h): 26 x 26 x 200
# block4a_expand_activation(middle in block4a): 52 x 52 x 432
# block3f_add(end of block3f): 52 x 52 x 72
efficientnet = EfficientNetB6(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 2304
f2_name = 'block6a_expand_activation'
f2_channel_num = 1200
f3_name = 'block4a_expand_activation'
f3_channel_num = 432

elif level == 7:
# input: 416 x 416 x 3
# top_activation: 13 x 13 x 2560
# block6a_expand_activation(middle in block6a): 26 x 26 x 1344
# block5j_add(end of block5j): 26 x 26 x 224
# block4a_expand_activation(middle in block4a): 52 x 52 x 480
# block3g_add(end of block3g): 52 x 52 x 80
efficientnet = EfficientNetB7(input_tensor=input_tensor, weights='imagenet', include_top=False)

f1_name = 'top_activation'
f1_channel_num = 2560
f2_name = 'block6a_expand_activation'
f2_channel_num = 1344
f3_name = 'block4a_expand_activation'
f3_channel_num = 480

else:
raise ValueError('Invalid efficientnet backbone type')

# f1 shape : 13 x 13 x f1_channel_num
# f2 shape : 26 x 26 x f2_channel_num
# f3 shape : 52 x 52 x f3_channel_num
feature_map_info = {'f1_name' : f1_name,
'f1_channel_num' : f1_channel_num,
'f2_name' : f2_name,
'f2_channel_num' : f2_channel_num,
'f3_name' : f3_name,
'f3_channel_num' : f3_channel_num,
}

return efficientnet, feature_map_info


def yolo2_efficientnet_body(inputs, num_anchors, num_classes, level=0):
'''
Create YOLO_v2 EfficientNet model CNN body in keras.
# Arguments
level: EfficientNet level number.
by default we use basic EfficientNetB0 as backbone
'''
efficientnet, feature_map_info = get_efficientnet_backbone_info(inputs, level=level)
f1_channel_num = feature_map_info['f1_channel_num']

conv_head1 = compose(
DarknetConv2D_BN_Leaky(f1_channel_num, (3, 3)),
DarknetConv2D_BN_Leaky(f1_channel_num, (3, 3)))(efficientnet.output)

f2 = efficientnet.get_layer('block6a_expand_activation').output

conv_head2 = DarknetConv2D_BN_Leaky(int(64*(f1_channel_num//1024)), (1, 1))(f2)
# TODO: Allow Keras Lambda to use func arguments for output_shape?
conv_head2_reshaped = Lambda(
space_to_depth_x2,
output_shape=space_to_depth_x2_output_shape,
name='space_to_depth')(conv_head2)

x = Concatenate()([conv_head2_reshaped, conv_head1])
x = DarknetConv2D_BN_Leaky(f1_channel_num, (3, 3))(x)
x = DarknetConv2D(num_anchors * (num_classes + 5), (1, 1), name='predict_conv')(x)
return Model(inputs, x)


def yolo2lite_efficientnet_body(inputs, num_anchors, num_classes, level=0):
'''
Create YOLO_v2 Lite EfficientNet model CNN body in keras.
# Arguments
level: EfficientNet level number.
by default we use basic EfficientNetB0 as backbone
'''
efficientnet, feature_map_info = get_efficientnet_backbone_info(inputs, level=level)
f1_channel_num = feature_map_info['f1_channel_num']

conv_head1 = compose(
Depthwise_Separable_Conv2D_BN_Leaky(f1_channel_num, (3, 3)),
Depthwise_Separable_Conv2D_BN_Leaky(f1_channel_num, (3, 3)))(efficientnet.output)

f2 = efficientnet.get_layer('block6a_expand_activation').output

conv_head2 = DarknetConv2D_BN_Leaky(int(64*(f1_channel_num//1024)), (1, 1))(f2)
# TODO: Allow Keras Lambda to use func arguments for output_shape?
conv_head2_reshaped = Lambda(
space_to_depth_x2,
output_shape=space_to_depth_x2_output_shape,
name='space_to_depth')(conv_head2)

x = Concatenate()([conv_head2_reshaped, conv_head1])
x = Depthwise_Separable_Conv2D_BN_Leaky(f1_channel_num, (3, 3))(x)
x = DarknetConv2D(num_anchors * (num_classes + 5), (1, 1), name='predict_conv')(x)
return Model(inputs, x)


def tiny_yolo2_efficientnet_body(inputs, num_anchors, num_classes, level=0):
'''
Create Tiny YOLO_v2 EfficientNet model CNN body in keras.
# Arguments
level: EfficientNet level number.
by default we use basic EfficientNetB0 as backbone
'''
efficientnet, feature_map_info = get_efficientnet_backbone_info(inputs, level=level)
f1_channel_num = feature_map_info['f1_channel_num']

y = compose(
DarknetConv2D_BN_Leaky(f1_channel_num, (3,3)),
DarknetConv2D(num_anchors*(num_classes+5), (1,1), name='predict_conv'))(efficientnet.output)

return Model(inputs, y)


def tiny_yolo2lite_efficientnet_body(inputs, num_anchors, num_classes, level=0):
'''
Create Tiny YOLO_v2 Lite EfficientNet model CNN body in keras.
# Arguments
level: EfficientNet level number.
by default we use basic EfficientNetB0 as backbone
'''
efficientnet, feature_map_info = get_efficientnet_backbone_info(inputs, level=level)
f1_channel_num = feature_map_info['f1_channel_num']

y = compose(
Depthwise_Separable_Conv2D_BN_Leaky(f1_channel_num, (3,3)),
DarknetConv2D(num_anchors*(num_classes+5), (1,1), name='predict_conv'))(efficientnet.output)

return Model(inputs, y)

3 changes: 3 additions & 0 deletions yolo3/models/yolo3_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def get_efficientnet_backbone_info(input_tensor, level=0):
else:
raise ValueError('Invalid efficientnet backbone type')

# f1 shape : 13 x 13 x f1_channel_num
# f2 shape : 26 x 26 x f2_channel_num
# f3 shape : 52 x 52 x f3_channel_num
feature_map_info = {'f1_name' : f1_name,
'f1_channel_num' : f1_channel_num,
'f2_name' : f2_name,
Expand Down

0 comments on commit f4db716

Please sign in to comment.