Skip to content

Commit

Permalink
Add missing TF layers (#792)
Browse files Browse the repository at this point in the history
Add following layers to tf.py:
- TFMP (MP)
- TFSPPCSPC (SPPCSPC)
- TFRepConv (RepConv)
  • Loading branch information
ibaiGorordo authored Sep 19, 2022
1 parent a6215c0 commit f2439f8
Showing 1 changed file with 62 additions and 5 deletions.
67 changes: 62 additions & 5 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import torch.nn as nn
from tensorflow import keras

from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
DWConvTranspose2d, Focus, autopad)
from models.common import (C3, MP, SPP, SPPF, SPPCSPC, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
RepConv, DWConvTranspose2d, Focus, autopad)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect
from utils.activations import SiLU
Expand Down Expand Up @@ -86,6 +86,36 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))

class TFRepConv(keras.layers.Layer):

def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, w=None):
super().__init__()


self.groups = g
self.in_channels = c1
self.out_channels = c2

assert k == 3
assert autopad(k, p) == 1
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"

padding_11 = autopad(k, p) - k // 2

self.act = activations(w.act) if act else tf.identity
rbr_reparam = keras.layers.Conv2D(
filters=c2,
kernel_size=k,
strides=s,
padding='SAME' if s == 1 else 'VALID',
use_bias=True,
kernel_initializer=keras.initializers.Constant(w.rbr_reparam.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.rbr_reparam.bias.numpy()))
self.rbr_reparam = rbr_reparam if s == 1 else keras.Sequential([TFPad(autopad(k, p)), rbr_reparam])
self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity

def call(self, inputs):
return self.act(self.bn(self.rbr_reparam(inputs)))

class TFDWConv(keras.layers.Layer):
# Depthwise convolution
Expand Down Expand Up @@ -239,6 +269,14 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))

class TFMP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, k=2, w=None):
super().__init__()
self.m = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='VALID')

def call(self, inputs):
return self.m(inputs)

class TFSPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
Expand Down Expand Up @@ -269,6 +307,24 @@ def call(self, inputs):
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))

class TFSPPCSPC(keras.layers.Layer):
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13), w=None):
super().__init__()
c_ = int(2 * c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = TFConv(c_, c_, 3, 1, w=w.cv3)
self.cv4 = TFConv(c_, c_, 1, 1, w=w.cv4)
self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
self.cv5 = TFConv(4 * c_, c_, 1, 1, w=w.cv5)
self.cv6 = TFConv(c_, c_, 3, 1, w=w.cv6)
self.cv7 = TFConv(2 * c_, c2, 1, 1, w=w.cv7)

def call(self, inputs):
x1 = self.cv4(self.cv3(self.cv1(inputs)))
y1 = self.cv6(self.cv5(tf.concat([x1] + [m(x1) for m in self.m], 3)))
y2 = self.cv2(inputs)
return self.cv7(tf.concat((y1, y2), 3))

class TFDetect(keras.layers.Layer):
# TF YOLOv5 Detect layer
Expand Down Expand Up @@ -355,6 +411,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m_str = m

m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
Expand All @@ -364,16 +421,16 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [
nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3x]:
nn.Conv2d, Conv, DWConv, RepConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, SPPCSPC, MixConv2d,
Focus, CrossConv, BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3x]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
elif m in [nn.BatchNorm2d, MP]:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
Expand Down

0 comments on commit f2439f8

Please sign in to comment.