From bab52f4cd3b848ef19905794316d3327d55fcd92 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 Sep 2024 17:08:57 -0700 Subject: [PATCH] Mambaout tweaks --- timm/models/mambaout.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 5c4722378..a33554a9c 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -12,7 +12,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead +from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model @@ -318,10 +318,12 @@ def __init__( super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate + self.output_fmt = 'NHWC' if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage if not isinstance(dims, (list, tuple)): dims = [dims] + act_layer = get_act_layer(act_layer) num_stage = len(depths) self.num_stage = num_stage @@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model): def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc', **kwargs @@ -477,6 +479,7 @@ def _cfg(url='', **kwargs): 'mambaout_small_rw': _cfg(), 'mambaout_base_slim_rw': _cfg(), 'mambaout_base_plus_rw': _cfg(), + 'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)), } @@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs): depths=(3, 4, 27, 3), dims=(128, 256, 512, 768), expansion_ratio=3.0, + conv_ratio=1.5, stem_mid_norm=False, downsample='conv_nf', ls_init_value=1e-6, + act_layer='silu', head_fn='norm_mlp', ) return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def test_mambaout(pretrained=False, **kwargs): + model_args = dict( + depths=(1, 1, 3, 1), + dims=(16, 32, 48, 64), + expansion_ratio=3, + stem_mid_norm=False, + downsample='conv_nf', + ls_init_value=1e-4, + act_layer='silu', + head_fn='norm_mlp', + ) + return _create_mambaout('test_mambaout', pretrained=pretrained, **dict(model_args, **kwargs))