Skip to content

Commit

Permalink
Two small fixes, num_classes in base class, add model tag
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jul 26, 2023
1 parent 3318e76 commit b71d60c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions timm/models/repvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def __init__(self, dim, num_classes, distillation=False):
self.distillation = distillation
if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.num_classes = num_classes

def forward(self, x):
if self.distillation:
Expand Down Expand Up @@ -248,6 +247,7 @@ def __init__(
self.grad_checkpointing = False
self.global_pool = global_pool
self.embed_dim = embed_dim
self.num_classes = num_classes

in_dim = embed_dim[0]
self.stem = RepViTStem(in_chans, in_dim, act_layer)
Expand Down Expand Up @@ -356,13 +356,13 @@ def _cfg(url='', **kwargs):

default_cfgs = generate_default_cfgs(
{
'repvit_m1': _cfg(
'repvit_m1.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
),
'repvit_m2': _cfg(
'repvit_m2.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
),
'repvit_m3': _cfg(
'repvit_m3.dist_in1k': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
),
}
Expand Down

0 comments on commit b71d60c

Please sign in to comment.