Skip to content

Commit

Permalink
Adding some more tiny test models to train
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 6, 2024
1 parent ee5b1e8 commit 6ab2af6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 6 deletions.
34 changes: 34 additions & 0 deletions timm/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,17 @@ def _cfgv2(url='', **kwargs):
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),

"test_convnext.r160_in1k": _cfg(
# hf_hub_id='timm/',
input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
"test_convnext2.r160_in1k": _cfg(
# hf_hub_id='timm/',
input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
"test_convnext3.r160_in1k": _cfg(
# hf_hub_id='timm/',
input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

})


Expand Down Expand Up @@ -1146,6 +1157,29 @@ def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
return model


@register_model
def test_convnext(pretrained=False, **kwargs) -> ConvNeXt:
model_args = dict(depths=[1, 2, 4, 2], dims=[24, 32, 48, 64], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
model = _create_convnext('test_convnext', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def test_convnext2(pretrained=False, **kwargs) -> ConvNeXt:
model_args = dict(depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
model = _create_convnext('test_convnext2', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def test_convnext3(pretrained=False, **kwargs) -> ConvNeXt:
model_args = dict(
depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), kernel_sizes=(7, 5, 5, 3), act_layer='silu')
model = _create_convnext('test_convnext3', pretrained=pretrained, **dict(model_args, **kwargs))
return model



register_model_deprecations(__name__, {
'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
Expand Down
10 changes: 10 additions & 0 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,10 @@ def _cfg(url='', **kwargs):
"test_efficientnet.r160_in1k": _cfg(
hf_hub_id='timm/',
input_size=(3, 160, 160), pool_size=(5, 5)),
"test_efficientnet_gn.r160_in1k": _cfg(
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 160, 160), pool_size=(5, 5)),
})


Expand Down Expand Up @@ -2792,6 +2796,12 @@ def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
return model


@register_model
def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
model = _gen_test_efficientnet(
'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
return model

register_model_deprecations(__name__, {
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
Expand Down
14 changes: 14 additions & 0 deletions timm/models/nfnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ def _dm_nfnet_cfg(
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),

test_nfnet=_nfnet_cfg(
depths=(1, 1, 1, 1), channels=(32, 64, 96, 128), feat_mult=1.5, group_size=8, bottle_ratio=0.25,
attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
)


Expand Down Expand Up @@ -730,6 +734,11 @@ def _dcfg(url='', **kwargs):
'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),

'test_nfnet.r160_in1k': _dcfg(
# hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
crop_pct=0.875, input_size=(3, 160, 160), pool_size=(5, 5)),
})


Expand Down Expand Up @@ -1029,3 +1038,8 @@ def nf_ecaresnet101(pretrained=False, **kwargs) -> NormFreeNet:
""" Normalization-Free ECA-ResNet101
"""
return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)


@register_model
def test_nfnet(pretrained=False, **kwargs) -> NormFreeNet:
return _create_normfreenet('test_nfnet', pretrained=pretrained, **kwargs)
28 changes: 22 additions & 6 deletions timm/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
Expand Down Expand Up @@ -286,7 +286,7 @@ def drop_blocks(drop_prob: float = 0.):


def make_blocks(
block_fn: Union[BasicBlock, Bottleneck],
block_fns: Tuple[Union[BasicBlock, Bottleneck]],
channels: Tuple[int, ...],
block_repeats: Tuple[int, ...],
inplanes: int,
Expand All @@ -304,7 +304,7 @@ def make_blocks(
net_block_idx = 0
net_stride = 4
dilation = prev_dilation = 1
for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
stride = 1 if stage_idx == 0 else 2
if net_stride >= output_stride:
Expand Down Expand Up @@ -490,8 +490,9 @@ def __init__(
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# Feature Blocks
block_fns = to_ntuple(len(channels))(block)
stage_modules, stage_feature_info = make_blocks(
block,
block_fns,
channels,
layers,
inplanes,
Expand All @@ -513,7 +514,7 @@ def __init__(
self.feature_info.extend(stage_feature_info)

# Head (Pooling and Classifier)
self.num_features = self.head_hidden_size = channels[-1] * block.expansion
self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

self.init_weights(zero_init_last=zero_init_last)
Expand Down Expand Up @@ -1301,6 +1302,11 @@ def _gcfg(url='', **kwargs):
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
first_conv='conv1.0'),

'test_resnet.r160_in1k': _cfg(
#hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 160, 160), pool_size=(5, 5), first_conv='conv1.0'),
})


Expand Down Expand Up @@ -2040,6 +2046,16 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))


@register_model
def test_resnet(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a tiny ResNet test model.
"""
model_args = dict(
block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock], layers=(1, 1, 1, 1),
stem_width=16, stem_type='deep', avg_down=True, channels=(32, 48, 48, 96))
return _create_resnet('test_resnet', pretrained, **dict(model_args, **kwargs))


register_model_deprecations(__name__, {
'tv_resnet34': 'resnet34.tv_in1k',
'tv_resnet50': 'resnet50.tv_in1k',
Expand Down

0 comments on commit 6ab2af6

Please sign in to comment.