Skip to content

Commit

Permalink
Merge pull request #293 from zezeze97/master
Browse files Browse the repository at this point in the history
Add SMILE
  • Loading branch information
XiaoguangHu01 authored Nov 23, 2022
2 parents b3f8c06 + 51ab3ed commit 8fac5b5
Show file tree
Hide file tree
Showing 11 changed files with 1,897 additions and 0 deletions.
98 changes: 98 additions & 0 deletions CV/SMILE/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SMILE: Self-Distilled MIxup for Efficient Transfer LEarning
## Introduction

This is the [PaddlePaddle](https://www.paddlepaddle.org.cn/) implementation of the SMILE (Spotlight on [INTERPOLATE@NeurIPS 2022](https://sites.google.com/view/interpolation-workshop?pli=1)) model for image classification.

In this work, we propose SMILE— Self-Distilled Mixup for EffIcient Transfer LEarning.
With mixed images as inputs, SMILE regularizes the outputs of CNN feature extractors to learn
from the mixed feature vectors of inputs (sample-to-feature mixup), in addition to the mixed labels.
Specifically, SMILE incorporates a mean teacher, inherited from the pre-trained model, to provide
the feature vectors of input samples in a self-distilling fashion, and mixes up the feature vectors
accordingly via a novel triplet regularizer. The triple regularizer balances the mixup effects in both
feature and label spaces while bounding the linearity in-between samples for pre-training tasks.



## Requirements
The code has been tested running under the following environments:

* python >= 3.7
* numpy >= 1.21
* paddlepaddle >= 2.2 (with suitable CUDA and cuDNN version)
* visualdl



## Model Training

### step1. Download dataset files
We conduct experiments on three popular object recognition datasets: CUB-200-2011, Stanford Cars and
FGVC-Aircraft. You can download it from the official link below.
- [CUB-200-2011](http://www.vision.caltech.edu/datasets/cub_200_2011/)
- [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)

Please organize your dataset in the following format.
```
dataset
├── train
│ ├── class_001
| | ├── 1.jpg
| | ├── 2.jpg
| | └── ...
│ ├── class_002
| | ├── 1.jpg
| | ├── 2.jpg
| | └── ...
│ └── ...
└── test
├── class_001
| ├── 1.jpg
| ├── 2.jpg
| └── ...
├── class_002
| ├── 1.jpg
| ├── 2.jpg
| └── ...
└── ...
```

### step2. Finetune

You can use the following command to finetune the target data using the SMILE algorithm. Log files and ckpts during training are saved in the ./output. Only the model with the highest accuracy on the validation set is saved during finetuning.
```
python finetune.py --name {name of your experiment} --train_dir {path of train dir} --eval_dir {path of eval dir} --model_arch resnet50 --gpu {gpu id} --regularizer smile
```

### step3. Test

You can also load the finetuning ckpts with the following command and test it on the test set.
```
python test.py --test_dir {path of test dir} --model_arch resnet50 --gpu {gpu id} --ckpts {path of finetuning ckpts}
```

## Results

|Dataset/Method | L2 | SMILE |
|---|---|---|
|CUB-200-2011 | 80.79 | 82.38 |
|Stanford-Cars| 90.72 | 91.74 |
|FGVC-Aircraft| 86.93 | 89.00 |



## Citation
If you use any source code included in this project in your work, please cite the following paper:

```
@article{Li2021SMILESM,
title={SMILE: Self-Distilled MIxup for Efficient Transfer LEarning},
author={Xingjian Li and Haoyi Xiong and Chengzhong Xu and Dejing Dou},
journal={ArXiv},
year={2021},
volume={abs/2103.13941}
}
```

## Copyright and License
Copyright 2019 Baidu.com, Inc. All Rights Reserved Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
54 changes: 54 additions & 0 deletions CV/SMILE/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from .resnet import ResNet # noqa: F401
from .resnet import resnet18 # noqa: F401
from .resnet import resnet34 # noqa: F401
from .resnet import resnet50 # noqa: F401
from .resnet import resnet101 # noqa: F401
from .resnet import resnet152 # noqa: F401
from .mobilenetv2 import MobileNetV2 # noqa: F401
from .mobilenetv2 import mobilenet_v2 # noqa: F401
from .vit import VisionTransformer
from .vit import build_vit
'''from .mobilenetv1 import MobileNetV1 # noqa: F401
from .mobilenetv1 import mobilenet_v1 # noqa: F401
from .mobilenetv2 import MobileNetV2 # noqa: F401
from .mobilenetv2 import mobilenet_v2 # noqa: F401
from .vgg import VGG # noqa: F401
from .vgg import vgg11 # noqa: F401
from .vgg import vgg13 # noqa: F401
from .vgg import vgg16 # noqa: F401
from .vgg import vgg19 # noqa: F401
from .lenet import LeNet # noqa: F401'''

__all__ = [ #noqa
'ResNet',
'resnet18',
'resnet34',
'resnet50',
'resnet101',
'resnet152',
'VGG',
'vgg11',
'vgg13',
'vgg16',
'vgg19',
'MobileNetV1',
'mobilenet_v1',
'MobileNetV2',
'mobilenet_v2',
'LeNet',
'ViT'
]
228 changes: 228 additions & 0 deletions CV/SMILE/backbones/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

import paddle.nn as nn
import paddle.nn.functional as F

from paddle.utils.download import get_weights_path_from_url

__all__ = []

model_urls = {
'mobilenetv2_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
'0340af0a901346c8d46f4529882fb63d')
}


def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)

if new_v < 0.9 * v:
new_v += divisor
return new_v


class ConvBNReLU(nn.Sequential):
def __init__(self,
in_planes,
out_planes,
kernel_size=3,
stride=1,
groups=1,
norm_layer=nn.BatchNorm2D):
padding = (kernel_size - 1) // 2

super(ConvBNReLU, self).__init__(
nn.Conv2D(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias_attr=False),
norm_layer(out_planes),
nn.ReLU6())


class InvertedResidual(nn.Layer):
def __init__(self,
inp,
oup,
stride,
expand_ratio,
norm_layer=nn.BatchNorm2D):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]

hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup

layers = []
if expand_ratio != 1:
layers.append(
ConvBNReLU(
inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.extend([
ConvBNReLU(
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
norm_layer=norm_layer),
nn.Conv2D(
hidden_dim, oup, 1, 1, 0, bias_attr=False),
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)

def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)


class MobileNetV2(nn.Layer):
def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
from paddle.vision.models import MobileNetV2
model = MobileNetV2()
"""
super(MobileNetV2, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
input_channel = 32
last_channel = 1280

block = InvertedResidual
round_nearest = 8
norm_layer = nn.BatchNorm2D
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]

input_channel = _make_divisible(input_channel * scale, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, scale),
round_nearest)
features = [
ConvBNReLU(
3, input_channel, stride=2, norm_layer=norm_layer)
]

for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * scale, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(
block(
input_channel,
output_channel,
stride,
expand_ratio=t,
norm_layer=norm_layer))
input_channel = output_channel

features.append(
ConvBNReLU(
input_channel,
self.last_channel,
kernel_size=1,
norm_layer=norm_layer))

self.features = nn.Sequential(*features)

if with_pool:
self.pool2d_avg = nn.AdaptiveAvgPool2D(1)

if self.num_classes > 0:
self.classifier = nn.Sequential(
nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes))
def forward(self, x):
fea = self.features(x)

if self.with_pool:
x = self.pool2d_avg(fea)
else:
x = fea

if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.classifier(x)
return x, fea


def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])

param = paddle.load(weight_path)
model.load_dict(param)

return model


def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
"""MobileNetV2
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from paddle.vision.models import mobilenet_v2
# build model
model = mobilenet_v2()
# build model and load imagenet pretrained weight
# model = mobilenet_v2(pretrained=True)
# build mobilenet v2 with scale=0.5
model = mobilenet_v2(scale=0.5)
"""
model = _mobilenet(
'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs)
return model
Loading

0 comments on commit 8fac5b5

Please sign in to comment.