Skip to content

Commit

Permalink
Allow group_size override for more efficientnet and mobilenetv3 based…
Browse files Browse the repository at this point in the history
… models
  • Loading branch information
rwightman committed Aug 21, 2024
1 parent 00c5be7 commit b9f020a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
44 changes: 30 additions & 14 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar

def _gen_mobilenet_v1(
variant, channel_multiplier=1.0, depth_multiplier=1.0,
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
"""
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381
Expand All @@ -503,7 +503,12 @@ def _gen_mobilenet_v1(
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
block_args=decode_arch_def(
arch_def,
depth_multiplier=depth_multiplier,
fix_first_last=fix_stem_head,
group_size=group_size,
),
num_features=head_features,
stem_size=32,
fix_stem=fix_stem_head,
Expand All @@ -517,7 +522,9 @@ def _gen_mobilenet_v1(


def _gen_mobilenet_v2(
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0,
group_size=None, fix_stem_head=False, pretrained=False, **kwargs
):
""" Generate MobileNet-V2 network
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381
Expand All @@ -533,7 +540,12 @@ def _gen_mobilenet_v2(
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
block_args=decode_arch_def(
arch_def,
depth_multiplier=depth_multiplier,
fix_first_last=fix_stem_head,
group_size=group_size,
),
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
stem_size=32,
fix_stem=fix_stem_head,
Expand Down Expand Up @@ -764,7 +776,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0


def _gen_efficientnetv2_base(
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 base model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -780,7 +792,7 @@ def _gen_efficientnetv2_base(
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
Expand Down Expand Up @@ -831,7 +843,8 @@ def _gen_efficientnetv2_s(
return model


def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_m(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Medium model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -849,7 +862,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=24,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand All @@ -861,7 +874,8 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
return model


def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_l(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Large model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -879,7 +893,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand All @@ -891,7 +905,8 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
return model


def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_xl(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Xtra-Large model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -909,7 +924,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand Down Expand Up @@ -1094,7 +1109,8 @@ def _gen_tinynet(
return model


def _gen_mobilenet_edgetpu(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_mobilenet_edgetpu(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
"""
Based on definitions in: https://github.com/tensorflow/models/tree/d2427a562f401c9af118e47af2f030a0a5599f55/official/projects/edgetpu/vision
"""
Expand Down Expand Up @@ -1170,7 +1186,7 @@ def _arch_def(chs: List[int], group_size: int):
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=num_features,
stem_size=stem_size,
stem_kernel_size=stem_kernel_size,
Expand Down
12 changes: 8 additions & 4 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,9 @@ def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrain
return model


def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
def _gen_mobilenet_v3(
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
) -> MobileNetV3:
"""Creates a MobileNet-V3 model.
Ref impl: ?
Expand Down Expand Up @@ -533,7 +535,7 @@ def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained:
]
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
block_args=decode_arch_def(arch_def, group_size=group_size),
num_features=num_features,
stem_size=16,
fix_stem=channel_multiplier < 0.75,
Expand Down Expand Up @@ -646,7 +648,9 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
return model


def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
def _gen_mobilenet_v4(
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs,
) -> MobileNetV3:
"""Creates a MobileNet-V4 model.
Ref impl: ?
Expand Down Expand Up @@ -877,7 +881,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
assert False, f'Unknown variant {variant}.'

model_kwargs = dict(
block_args=decode_arch_def(arch_def),
block_args=decode_arch_def(arch_def, group_size=group_size),
head_bias=False,
head_norm=True,
num_features=num_features,
Expand Down

0 comments on commit b9f020a

Please sign in to comment.