From f2439f894c32a5ad09dd099aae0586ff01080301 Mon Sep 17 00:00:00 2001
From: Ibai Gorordo <43162939+ibaiGorordo@users.noreply.github.com>
Date: Tue, 20 Sep 2022 06:31:41 +0900
Subject: [PATCH] Add missing TF layers (#792)

Add following layers to tf.py:
- TFMP (MP)
- TFSPPCSPC (SPPCSPC)
- TFRepConv (RepConv)
---
 models/tf.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 62 insertions(+), 5 deletions(-)

diff --git a/models/tf.py b/models/tf.py
index b0d98cc2a3..2d85066056 100644
--- a/models/tf.py
+++ b/models/tf.py
@@ -27,8 +27,8 @@
 import torch.nn as nn
 from tensorflow import keras
 
-from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
-                           DWConvTranspose2d, Focus, autopad)
+from models.common import (C3, MP, SPP, SPPF, SPPCSPC, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
+                           RepConv, DWConvTranspose2d, Focus, autopad)
 from models.experimental import MixConv2d, attempt_load
 from models.yolo import Detect
 from utils.activations import SiLU
@@ -86,6 +86,36 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
     def call(self, inputs):
         return self.act(self.bn(self.conv(inputs)))
 
+class TFRepConv(keras.layers.Layer):
+    # TF port of RepConv: one 3x3 Conv2D built from the `w.rbr_reparam` weights (only k=3, g=1 are accepted)
+    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, w=None):
+        super().__init__()
+
+
+        self.groups = g
+        self.in_channels = c1
+        self.out_channels = c2
+
+        assert k == 3
+        assert autopad(k, p) == 1
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+
+        padding_11 = autopad(k, p) - k // 2
+
+        self.act = activations(w.act) if act else tf.identity
+        rbr_reparam = keras.layers.Conv2D(
+            filters=c2,
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=True,
+            kernel_initializer=keras.initializers.Constant(w.rbr_reparam.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.rbr_reparam.bias.numpy()))
+        self.rbr_reparam = rbr_reparam if s == 1 else keras.Sequential([TFPad(autopad(k, p)), rbr_reparam])
+        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+
+    def call(self, inputs):
+        return self.act(self.bn(self.rbr_reparam(inputs)))
 
 class TFDWConv(keras.layers.Layer):
     # Depthwise convolution
@@ -239,6 +269,14 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
     def call(self, inputs):
         return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
 
+class TFMP(keras.layers.Layer):
+    # Max-pooling downsample (TF port of MP): fixed 2x2 pool with stride 2; `k` is accepted but not used
+    def __init__(self, k=2, w=None):
+        super().__init__()
+        self.m = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='VALID')
+
+    def call(self, inputs):
+        return self.m(inputs)
 
 class TFSPP(keras.layers.Layer):
     # Spatial pyramid pooling layer used in YOLOv3-SPP
@@ -269,6 +307,24 @@ def call(self, inputs):
         y2 = self.m(y1)
         return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
 
+class TFSPPCSPC(keras.layers.Layer):
+    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13), w=None):
+        super().__init__()
+        c_ = int(2 * c2 * e)  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
+        self.cv3 = TFConv(c_, c_, 3, 1, w=w.cv3)
+        self.cv4 = TFConv(c_, c_, 1, 1, w=w.cv4)
+        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
+        self.cv5 = TFConv(4 * c_, c_, 1, 1, w=w.cv5)
+        self.cv6 = TFConv(c_, c_, 3, 1, w=w.cv6)
+        self.cv7 = TFConv(2 * c_, c2, 1, 1, w=w.cv7)
+
+    def call(self, inputs):
+        x1 = self.cv4(self.cv3(self.cv1(inputs)))
+        y1 = self.cv6(self.cv5(tf.concat([x1] + [m(x1) for m in self.m], 3)))
+        y2 = self.cv2(inputs)
+        return self.cv7(tf.concat((y1, y2), 3))
 
 class TFDetect(keras.layers.Layer):
     # TF YOLOv5 Detect layer
@@ -355,6 +411,7 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
 
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
         m_str = m
+        m = eval(m) if isinstance(m, str) else m  # eval strings
         for j, a in enumerate(args):
             try:
@@ -364,8 +421,8 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
 
         n = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in [
-                nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
-                BottleneckCSP, C3, C3x]:
+                nn.Conv2d, Conv, DWConv, RepConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, SPPCSPC, MixConv2d,
+                Focus, CrossConv, BottleneckCSP, C3, C3x]:
             c1, c2 = ch[f], args[0]
             c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
 
@@ -373,7 +430,7 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
             if m in [BottleneckCSP, C3, C3x]:
                 args.insert(2, n)
                 n = 1
-        elif m is nn.BatchNorm2d:
+        elif m in [nn.BatchNorm2d, MP]:
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)