-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
744 lines (677 loc) · 25 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
import torch
import copy
import torch.nn as nn
from typing import List, Tuple
class RepBlock(nn.Module):
"""
MobileOne-style residual blocks, including residual joins and re-parameterization convolutions
"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
groups: int = 1,
inference_mode: bool = False,
rbr_conv_kernel_list: List[int] = [7, 3],
use_bn_conv: bool = False,
act_layer: nn.Module = nn.ReLU,
skip_include_bn: bool = True,
) -> None:
"""Construct a Re-parameterization module.
:param in_channels: Number of input channels.
:param out_channels: Number of output channels.
:param stride: Stride for convolution.
:param groups: Number of groups for convolution.
:param inference_mode: Whether to use inference mode.
:param rbr_conv_kernel_list: List of kernel sizes for re-parameterizable convolutions.
:param use_bn_conv: Whether the bn is in front of conv, if false, conv is in front of bn
:param act_layer: Activation layer.
:param skip_include_bn: Whether to include bn in skip connection.
"""
super(RepBlock, self).__init__()
self.inference_mode = inference_mode
self.groups = groups
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
self.rbr_conv_kernel_list = sorted(rbr_conv_kernel_list, reverse=True)
self.num_conv_branches = len(self.rbr_conv_kernel_list)
self.kernel_size = self.rbr_conv_kernel_list[0]
self.use_bn_conv = use_bn_conv
self.skip_include_bn = skip_include_bn
self.activation = act_layer()
if inference_mode:
self.reparam_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.kernel_size,
stride=stride,
padding=self.kernel_size // 2,
groups=groups,
bias=True,
)
else:
# Re-parameterizable skip connection
if out_channels == in_channels and stride == 1:
if self.skip_include_bn:
# Use residual connections that include BN
self.rbr_skip = nn.BatchNorm2d(num_features=in_channels)
else:
# Use residual connections
self.rbr_skip = nn.Identity()
else:
# Use residual connections
self.rbr_skip = None
# Re-parameterizable conv branches
rbr_conv = list()
for kernel_size in self.rbr_conv_kernel_list:
if self.use_bn_conv:
rbr_conv.append(
self._bn_conv(
in_chans=in_channels,
out_chans=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
groups=groups,
)
)
else:
rbr_conv.append(
self._conv_bn(
in_chans=in_channels,
out_chans=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
groups=groups,
)
)
self.rbr_conv = nn.ModuleList(rbr_conv)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply forward pass."""
# Inference mode forward pass.
if self.inference_mode:
return self.activation(self.reparam_conv(x))
# Multi-branched train-time forward pass.
# Skip branch output
identity_out = 0
if self.rbr_skip is not None:
identity_out = self.rbr_skip(x)
# Other branches
out = identity_out
for ix in range(self.num_conv_branches):
out = out + self.rbr_conv[ix](x)
return self.activation(out)
def reparameterize(self):
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
architecture used at training time to obtain a plain CNN-like structure
for inference.
"""
if self.inference_mode:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2d(
in_channels=self.rbr_conv[0].conv.in_channels,
out_channels=self.rbr_conv[0].conv.out_channels,
kernel_size=self.rbr_conv[0].conv.kernel_size,
stride=self.rbr_conv[0].conv.stride,
padding=self.rbr_conv[0].conv.padding,
dilation=self.rbr_conv[0].conv.dilation,
groups=self.rbr_conv[0].conv.groups,
bias=True,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
# Delete un-used branches
for para in self.parameters():
para.detach_()
self.__delattr__("rbr_conv")
if hasattr(self, "rbr_skip"):
self.__delattr__("rbr_skip")
self.inference_mode = True
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to obtain re-parameterized kernel and bias.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
:return: Tuple of (kernel, bias) after fusing branches.
"""
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_skip_tensor(self.rbr_skip)
# get weights and bias of conv branches
kernel_conv = 0
bias_conv = 0
for ix in range(self.num_conv_branches):
if self.use_bn_conv:
_kernel, _bias = self._fuse_bn_conv_tensor(self.rbr_conv[ix])
else:
_kernel, _bias = self._fuse_conv_bn_tensor(self.rbr_conv[ix])
# pad kernel
if _kernel.shape[-1] < self.kernel_size:
pad = (self.kernel_size - _kernel.shape[-1]) // 2
_kernel = torch.nn.functional.pad(_kernel, [pad, pad, pad, pad])
kernel_conv += _kernel
bias_conv += _bias
kernel_final = kernel_conv + kernel_identity
bias_final = bias_conv + bias_identity
return kernel_final, bias_final
def _fuse_skip_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param branch: skip branch, maybe include bn layer
:return: Tuple of (kernel, bias) after fusing batchnorm.
"""
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_value = torch.zeros(
(self.in_channels, input_dim, self.kernel_size, self.kernel_size),
dtype=self.rbr_conv[0].conv.weight.dtype,
device=self.rbr_conv[0].conv.weight.device,
)
for i in range(self.in_channels):
kernel_value[
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
] = 1
self.id_tensor = kernel_value
if isinstance(branch, nn.Identity):
kernel = self.id_tensor
return kernel, torch.zeros(
(self.in_channels),
dtype=self.rbr_conv[0].conv.weight.dtype,
device=self.rbr_conv[0].conv.weight.device,
)
else:
assert isinstance(
branch, nn.BatchNorm2d
), "Make sure the module in skip is nn. BatchNorm2d"
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
def _fuse_bn_conv_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
"""先bn,后conv
:param branch:
:return: Tuple of (kernel, bias) after fusing batchnorm.
"""
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
std = (running_var + eps).sqrt()
t = gamma / std
t = torch.stack([t] * (kernel.shape[0] * kernel.shape[1]//t.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
t_beta = torch.stack([beta] * (kernel.shape[0] * kernel.shape[1]//beta.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
t_running_mean = torch.stack([running_mean] * (kernel.shape[0] * kernel.shape[1]//running_mean.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
return kernel * t, torch.sum(
kernel
* (
t_beta - t_running_mean * t
),
dim=(1, 2, 3),
)
def _fuse_conv_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
"""First conv, then bn
:param branch:
:return: Tuple of (kernel, bias) after fusing batchnorm.
"""
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
def _conv_bn(
self,
in_chans: int,
out_chans: int,
kernel_size: int,
stride: int,
padding: int,
groups: int,
) -> nn.Sequential:
"""First conv, then bn
:param kernel_size: Size of the convolution kernel.
:param padding: Zero-padding size.
:return: Conv-BN module.
"""
mod_list = nn.Sequential()
mod_list.add_module(
"conv",
nn.Conv2d(
in_channels=in_chans,
out_channels=out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
)
mod_list.add_module("bn", nn.BatchNorm2d(num_features=out_chans))
return mod_list
def _bn_conv(
self,
in_chans: int,
out_chans: int,
kernel_size: int,
stride: int,
padding: int,
groups: int,
) -> nn.Sequential:
"""Add bn first, then conv"""
mod_list = nn.Sequential()
mod_list.add_module("bn", nn.BatchNorm2d(num_features=in_chans))
mod_list.add_module(
"conv",
nn.Conv2d(
in_channels=in_chans,
out_channels=out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
)
return mod_list
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.dim = dim
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttnTokenMixer(nn.Module):
def __init__(
self,
in_chans: int,
num_heads: int,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
attn_head_dim=None,
inference_mode: bool = False,
use_CPE: bool = False,
) -> None:
super().__init__()
self.norm = nn.LayerNorm(in_chans)
self.attn = Attention(
in_chans,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
)
self.norm = nn.LayerNorm(in_chans)
if use_CPE:
self.CPE = RepBlock(
in_channels=in_chans,
out_channels=in_chans,
rbr_conv_kernel_list=[3],
stride=1,
groups=in_chans,
inference_mode=inference_mode,
act_layer=nn.Identity,
skip_include_bn=False,
)
else:
self.CPE = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, Hp, Wp = x.shape
x = self.CPE(x)
x = x.flatten(2).transpose(1, 2)
x = self.attn(self.norm(x))
x = x.permute(0, 2, 1).reshape(B, C, Hp, Wp).contiguous()
return x
class FastViTBlock(nn.Module):
def __init__(
self,
chans: int,
num_heads: int,
inference_mode: bool,
act_layer: nn.Module,
use_attn: bool,
expand_ratio: int,
use_CPE: bool = False,
) -> None:
super().__init__()
if use_attn:
# AttnTokenMixer
self.token_mixer = AttnTokenMixer(
chans, num_heads, inference_mode=inference_mode, use_CPE=use_CPE
)
else:
# RepMixer
self.token_mixer = RepBlock(
in_channels=chans,
out_channels=chans,
stride=1,
groups=chans,
rbr_conv_kernel_list=[3],
use_bn_conv=True,
act_layer=act_layer,
inference_mode=inference_mode,
skip_include_bn=False,
)
mid_chans = chans * expand_ratio
# conv_ffn
self.conv_ffn = nn.Sequential(
nn.Conv2d(chans, chans, kernel_size=7, padding=3, groups=chans, bias=False),
nn.BatchNorm2d(chans),
nn.Conv2d(chans, mid_chans, kernel_size=1, padding=0),
act_layer(),
nn.Conv2d(mid_chans, chans, kernel_size=1, padding=0),
)
def forward(self, x):
x = self.token_mixer(x)
x = x + self.conv_ffn(x)
return x
class FastStage(nn.Module):
def __init__(
self,
in_chans: int,
out_chans: int,
use_attn: bool,
num_heads: int,
num_blocks_per_stage: int,
inference_mode: bool,
act_layer: nn.Module,
expand_ratio: int,
use_patch_embed: bool,
) -> None:
"""
Constructs a FastStage
:param in_chans: Number of input channels.
:param out_chans: Number of output channels.
:param num_heads: Number of heads for attention If use_attn is True.
:param use_attn: Whether to use attention.
:param num_blocks_per_stage: Number of blocks per stage.
:param inference_mode: Whether to use inference mode.
:param act_layer: Activation layer.
:param expand_ratio: Expansion ratio in conv_ffn.
:param use_patch_embed: Whether to use patch embedding.
"""
super().__init__()
self.num_blocks_per_stage = num_blocks_per_stage
if use_patch_embed:
self.patch_embed = nn.Sequential(
RepBlock(
in_channels=in_chans,
out_channels=in_chans,
rbr_conv_kernel_list=[7, 3],
stride=2,
groups=in_chans,
inference_mode=inference_mode,
act_layer=act_layer,
),
RepBlock(
in_channels=in_chans,
out_channels=out_chans,
rbr_conv_kernel_list=[1],
stride=1,
groups=1,
inference_mode=inference_mode,
act_layer=act_layer,
),
)
else:
self.patch_embed = nn.Identity()
# FastViTBlock, CPE is only used on the first block
self.blocks = nn.Sequential(
*[
FastViTBlock(
chans=out_chans,
num_heads=num_heads,
inference_mode=inference_mode,
act_layer=act_layer,
use_attn=use_attn,
expand_ratio=expand_ratio,
use_CPE=(i == 0),
)
for i in range(num_blocks_per_stage)
]
)
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
return x
class FastVit(nn.Module):
def __init__(
self,
num_classes: int = 1000,
inference_mode: bool = False,
in_chans_list: Tuple[int] = (48, 48, 96, 192),
out_chans_list: Tuple[int] = (48, 96, 192, 384),
blocks_per_stage: Tuple[int] = (2, 2, 4, 2),
expand_ratio: Tuple[int] = (4, 4, 4, 4),
use_attn: Tuple[bool] = (False, False, False, False),
use_patchEmb: Tuple[bool] = (False, True, True, True),
act_layer: nn.Module = nn.ReLU,
) -> None:
"""
Constructs a FastVit model
:param num_classes: Number of classes for classification head.
:param inference_mode: Whether to use inference mode.
:param in_chans_list: List of input channels for each stage.
:param out_chans_list: List of output channels for each stage.
:param blocks_per_stage: List of number of blocks for each stage.
:param expand_ratio: List of expansion ratios for each stage.
:param use_attn: List of whether to use attention for each stage.
:param use_patchEmb: List of whether to use patch embedding for each stage.
:param act_layer: Activation layer.
"""
super().__init__()
self.stem = nn.Sequential(
RepBlock(
in_channels=3,
out_channels=in_chans_list[0],
rbr_conv_kernel_list=[3, 1],
stride=2,
groups=1,
inference_mode=inference_mode,
act_layer=act_layer,
),
RepBlock(
in_channels=in_chans_list[0],
out_channels=in_chans_list[0],
rbr_conv_kernel_list=[3, 1],
stride=2,
groups=in_chans_list[0],
inference_mode=inference_mode,
act_layer=act_layer,
),
RepBlock(
in_channels=in_chans_list[0],
out_channels=in_chans_list[0],
rbr_conv_kernel_list=[1],
stride=1,
groups=1,
inference_mode=inference_mode,
act_layer=act_layer,
skip_include_bn=True,
),
)
self.stages = nn.Sequential(
*(
FastStage(
in_chans=in_chans_list[i],
out_chans=out_chans_list[i],
num_blocks_per_stage=blocks_per_stage[i],
inference_mode=inference_mode,
use_attn=use_attn[i],
num_heads=8,
expand_ratio=expand_ratio[i],
act_layer=nn.ReLU,
use_patch_embed=use_patchEmb[i],
)
for i in range(len(blocks_per_stage))
)
)
self.last_block = RepBlock(
in_channels=out_chans_list[-1],
out_channels=out_chans_list[-1],
stride=1,
groups=out_chans_list[-1],
inference_mode=inference_mode,
rbr_conv_kernel_list=[3, 1],
act_layer=act_layer,
)
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(),
nn.Linear(in_features=out_chans_list[-1], out_features=num_classes),
)
def forward(self, x) -> torch.Tensor:
x = self.stem(x)
x = self.stages(x)
x = self.last_block(x)
x = self.head(x)
return x
def fast_vit(
num_classes: int = 1000, inference_mode: bool = False, variant: str = "T8"
) -> nn.Module:
"""
Constructs a FastVit model
:param num_classes: Number of classes for classification head.
:param inference_mode: Whether to use inference mode.
:param variant: Variant of FastVit.
"""
PARAMS = {
"T8": {
"in_chans_list": (48, 48, 96, 192),
"out_chans_list": (48, 96, 192, 384),
"blocks_per_stage": (2, 2, 4, 2),
"expand_ratio": (3, 3, 3, 3),
"use_attn": (False, False, False, False),
"use_patchEmb": (False, True, True, True),
},
"T12": {
"in_chans_list": (64, 64, 128, 256),
"out_chans_list": (64, 128, 256, 512),
"blocks_per_stage": (2, 2, 6, 2),
"expand_ratio": (3, 3, 3, 3),
"use_attn": (False, False, False, False),
"use_patchEmb": (False, True, True, True),
},
"S12": {
"in_chans_list": (64, 64, 128, 256),
"out_chans_list": (64, 128, 256, 512),
"blocks_per_stage": (2, 2, 6, 2),
"expand_ratio": (4, 4, 4, 4),
"use_attn": (False, False, False, True),
"use_patchEmb": (False, True, True, True),
},
"SA12": {
"in_chans_list": (64, 64, 128, 256),
"out_chans_list": (64, 128, 256, 512),
"blocks_per_stage": (2, 2, 6, 2),
"expand_ratio": (4, 4, 4, 4),
"use_attn": (False, False, False, True),
"use_patchEmb": (False, True, True, True),
},
"SA24": {
"in_chans_list": (64, 64, 128, 256),
"out_chans_list": (64, 128, 256, 512),
"blocks_per_stage": (4, 4, 12, 4),
"expand_ratio": (4, 4, 4, 4),
"use_attn": (False, False, False, True),
"use_patchEmb": (False, True, True, True),
},
"SA36": {
"in_chans_list": (64, 64, 128, 256),
"out_chans_list": (64, 128, 256, 512),
"blocks_per_stage": (6, 6, 18, 6),
"expand_ratio": (4, 4, 4, 4),
"use_attn": (False, False, False, True),
"use_patchEmb": (False, True, True, True),
},
"MA36": {
"in_chans_list": (76, 76, 152, 304),
"out_chans_list": (76, 152, 304, 608),
"blocks_per_stage": (6, 6, 18, 6),
"expand_ratio": (4, 4, 4, 4),
"use_attn": (False, False, False, True),
"use_patchEmb": (False, True, True, True),
},
}
if variant not in PARAMS:
raise ValueError(
"Invalid variant: {},valiable keys are {}".format(variant, PARAMS.keys())
)
variant_params = PARAMS[variant]
return FastVit(
num_classes=num_classes, inference_mode=inference_mode, **variant_params
)
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
"""
Method returns a model where a multi-branched structure
used in training is re-parameterized into a single branch
for inference.
:param model: MobileOne model in train mode.
:return: MobileOne model in inference mode.
"""
# Avoid editing original graph
model = copy.deepcopy(model)
for module in model.modules():
if hasattr(module, "reparameterize"):
module.reparameterize()
return model
if __name__ == "__main__":
from fvcore.nn import FlopCountAnalysis, flop_count_table
inp = torch.randn(1, 3, 256, 256)
model = fast_vit(variant="T8")
flops = FlopCountAnalysis(model, inp)
print(flop_count_table(flops, max_depth=5))
res = model(inp)
model = reparameterize_model(model)
flops = FlopCountAnalysis(model, inp)
print(flop_count_table(flops, max_depth=5))