
Mambaout tweaks
rwightman committed Sep 14, 2024
1 parent 79ce89d commit bab52f4
Showing 1 changed file with 22 additions and 2 deletions.
timm/models/mambaout.py
@@ -12,7 +12,7 @@
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
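
The newly imported get_act_layer helper resolves an activation given by name, which is what lets the model variants below pass act_layer='silu' as a plain string. A minimal sketch of that resolution step, assuming timm's usual string-to-class mapping:

    from timm.layers import get_act_layer

    act_cls = get_act_layer('silu')  # resolves the name to an nn.SiLU-style class
    act = act_cls()                  # instantiate like any other nn.Module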
@@ -318,10 +318,12 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.output_fmt = 'NHWC'
         if not isinstance(depths, (list, tuple)):
             depths = [depths]  # it means the model has only one stage
         if not isinstance(dims, (list, tuple)):
             dims = [dims]
+        act_layer = get_act_layer(act_layer)
 
         num_stage = len(depths)
         self.num_stage = num_stage
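
Declaring self.output_fmt = 'NHWC' advertises that MambaOut keeps its tensors channels-last. A sketch of what that implies for feature extraction, assuming forward_features returns the pre-head feature map of the stride-32 backbone (using the test_mambaout model registered further down):

    import torch
    import timm

    m = timm.create_model('test_mambaout')
    feats = m.forward_features(torch.randn(1, 3, 160, 160))
    print(feats.shape)  # expect channels-last (1, 5, 5, 64), not (1, 64, 5, 5)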
@@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model):
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 1.0, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
@@ -477,6 +479,7 @@ def _cfg(url='', **kwargs):
     'mambaout_small_rw': _cfg(),
     'mambaout_base_slim_rw': _cfg(),
     'mambaout_base_plus_rw': _cfg(),
+    'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
 }
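
The pool_size defaults track the final feature map instead of being left as None: with a 4x stem and three 2x downsamples the total stride is 32 (an assumption based on the standard 4-stage layout), so the two configs work out as:

    assert 224 // 32 == 7  # default cfg at 224x224 -> pool_size=(7, 7)
    assert 160 // 32 == 5  # test_mambaout at 160x160 -> pool_size=(5, 5)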


@@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
         expansion_ratio=3.0,
+        conv_ratio=1.5,
         stem_mid_norm=False,
         downsample='conv_nf',
         ls_init_value=1e-6,
         act_layer='silu',
         head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def test_mambaout(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(1, 1, 3, 1),
+        dims=(16, 32, 48, 64),
+        expansion_ratio=3,
+        stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-4,
+        act_layer='silu',
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('test_mambaout', pretrained=pretrained, **dict(model_args, **kwargs))
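
A quick smoke test of the newly registered model through timm's standard create_model entry point; the output shape assumes the default 1000-class head from _cfg:

    import torch
    import timm

    model = timm.create_model('test_mambaout', pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 160, 160))  # default cfg input_size
    print(out.shape)  # expected: torch.Size([1, 1000])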
