diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index b3f3350e9..612f8acdc 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -951,6 +951,17 @@ def _cfgv2(url='', **kwargs):
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
+
+    "test_convnext.r160_in1k": _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    "test_convnext2.r160_in1k": _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    "test_convnext3.r160_in1k": _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 160, 160), pool_size=(5, 5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+
 })
 
 
@@ -1146,6 +1157,29 @@ def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
     return model
 
 
+@register_model
+def test_convnext(pretrained=False, **kwargs) -> ConvNeXt:
+    model_args = dict(depths=[1, 2, 4, 2], dims=[24, 32, 48, 64], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
+    model = _create_convnext('test_convnext', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def test_convnext2(pretrained=False, **kwargs) -> ConvNeXt:
+    model_args = dict(depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
+    model = _create_convnext('test_convnext2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def test_convnext3(pretrained=False, **kwargs) -> ConvNeXt:
+    model_args = dict(
+        depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), kernel_sizes=(7, 5, 5, 3), act_layer='silu')
+    model = _create_convnext('test_convnext3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+
 register_model_deprecations(__name__, {
     'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
     'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index e097d8229..8c9fa416a 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -1804,6 +1804,10 @@ def _cfg(url='', **kwargs):
     "test_efficientnet.r160_in1k": _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), pool_size=(5, 5)),
+    "test_efficientnet_gn.r160_in1k": _cfg(
+        hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 160, 160), pool_size=(5, 5)),
 })
 
 
@@ -2792,6 +2796,12 @@ def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
     return model
 
 
+@register_model
+def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
+    model = _gen_test_efficientnet(
+        'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
+    return model
+
 register_model_deprecations(__name__, {
     'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
     'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
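The only delta between `test_efficientnet` and the new `test_efficientnet_gn` variant is the norm layer: `partial(GroupNormAct, group_size=8)` binds a fixed group size, and `GroupNormAct` then derives the group count from each layer's channel width. A minimal sketch of that behaviour, on my reading of timm's `GroupNormAct` (which computes `num_groups = num_channels // group_size` when `group_size` is set):

```python
from functools import partial

from timm.layers import GroupNormAct

# Bind group_size once; each norm layer derives its own group count
# from the channel width it is constructed with.
norm_layer = partial(GroupNormAct, group_size=8)

gn = norm_layer(32)   # 32 channels -> 32 // 8 = 4 groups
print(gn.num_groups)  # 4 (GroupNormAct subclasses nn.GroupNorm)
```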
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 79cfde7ac..42aaee224 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -607,6 +607,10 @@ def _dm_nfnet_cfg(
     nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
     nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
     nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
+
+    test_nfnet=_nfnet_cfg(
+        depths=(1, 1, 1, 1), channels=(32, 64, 96, 128), feat_mult=1.5, group_size=8, bottle_ratio=0.25,
+        attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
 )
 
 
@@ -730,6 +734,11 @@ def _dcfg(url='', **kwargs):
     'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
     'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
     'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
+
+    'test_nfnet.r160_in1k': _dcfg(
+        # hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        crop_pct=0.875, input_size=(3, 160, 160), pool_size=(5, 5)),
 })
 
 
@@ -1029,3 +1038,8 @@ def nf_ecaresnet101(pretrained=False, **kwargs) -> NormFreeNet:
     """ Normalization-Free ECA-ResNet101
     """
     return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def test_nfnet(pretrained=False, **kwargs) -> NormFreeNet:
+    return _create_normfreenet('test_nfnet', pretrained=pretrained, **kwargs)
\ No newline at end of file
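`test_nfnet` squeezes the NFNet recipe down to single-block stages with narrow channels. Note `feat_mult=1.5`: if `_nfnet_cfg` computes the final feature width as `int(channels[-1] * feat_mult)`, as the other presets suggest, the head should expose 192 features instead of the usual 2x expansion. A quick check, assuming the registration above is applied (the expected value is an inference from the cfg, not a verified output):

```python
import timm

# channels=(32, 64, 96, 128) with feat_mult=1.5
m = timm.create_model('test_nfnet', pretrained=False)
print(m.num_features)  # expected 192 == int(128 * 1.5)
```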
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index a80954ccc..4bda521d7 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -16,8 +16,8 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
-    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
+from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
+    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -286,7 +286,7 @@ def drop_blocks(drop_prob: float = 0.):
 
 
 def make_blocks(
-        block_fn: Union[BasicBlock, Bottleneck],
+        block_fns: Tuple[Union[BasicBlock, Bottleneck]],
         channels: Tuple[int, ...],
         block_repeats: Tuple[int, ...],
         inplanes: int,
@@ -304,7 +304,7 @@ def make_blocks(
     net_block_idx = 0
     net_stride = 4
     dilation = prev_dilation = 1
-    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+    for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
         stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
         stride = 1 if stage_idx == 0 else 2
         if net_stride >= output_stride:
@@ -490,8 +490,9 @@ def __init__(
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
+        block_fns = to_ntuple(len(channels))(block)
         stage_modules, stage_feature_info = make_blocks(
-            block,
+            block_fns,
             channels,
             layers,
             inplanes,
@@ -513,7 +514,7 @@ def __init__(
         self.feature_info.extend(stage_feature_info)
 
         # Head (Pooling and Classifier)
-        self.num_features = self.head_hidden_size = channels[-1] * block.expansion
+        self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         self.init_weights(zero_init_last=zero_init_last)
@@ -1301,6 +1302,11 @@ def _gcfg(url='', **kwargs):
         hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
         first_conv='conv1.0'),
+
+    'test_resnet.r160_in1k': _cfg(
+        #hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 160, 160), pool_size=(5, 5), first_conv='conv1.0'),
 })
 
 
@@ -2040,6 +2046,16 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
     return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
 
 
+@register_model
+def test_resnet(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a tiny ResNet test model.
+    """
+    model_args = dict(
+        block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock], layers=(1, 1, 1, 1),
+        stem_width=16, stem_type='deep', avg_down=True, channels=(32, 48, 48, 96))
+    return _create_resnet('test_resnet', pretrained, **dict(model_args, **kwargs))
+
+
 register_model_deprecations(__name__, {
     'tv_resnet34': 'resnet34.tv_in1k',
     'tv_resnet50': 'resnet50.tv_in1k',
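The resnet.py change is the most invasive one: `make_blocks` now takes a per-stage tuple of block classes (`to_ntuple` expands a single class to all stages, so existing models are unaffected), and `test_resnet` exercises the new path by mixing `BasicBlock` and `Bottleneck` stages in one network. A short smoke test, assuming the patch above is applied:

```python
import torch
import timm

# Mixed BasicBlock/Bottleneck stages go through the new block_fns handling.
model = timm.create_model('test_resnet', pretrained=False).eval()

x = torch.randn(1, 3, 160, 160)  # matches the r160 cfg input size
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])
```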