diff --git a/CV/SMILE/README.md b/CV/SMILE/README.md
new file mode 100644
index 00000000..ac44a2cc
--- /dev/null
+++ b/CV/SMILE/README.md
@@ -0,0 +1,98 @@
+# SMILE: Self-Distilled MIxup for Efficient Transfer LEarning
+## Introduction
+
+This is the [PaddlePaddle](https://www.paddlepaddle.org.cn/) implementation of SMILE (spotlight at [INTERPOLATE@NeurIPS 2022](https://sites.google.com/view/interpolation-workshop?pli=1)) for image classification.
+
+In this work, we propose SMILE (Self-Distilled MIxup for Efficient Transfer LEarning).
+With mixed images as inputs, SMILE regularizes the outputs of CNN feature extractors to learn
+from the mixed feature vectors of inputs (sample-to-feature mixup), in addition to the mixed labels.
+Specifically, SMILE incorporates a mean teacher, inherited from the pre-trained model, to provide
+the feature vectors of input samples in a self-distilling fashion, and mixes up the feature vectors
+accordingly via a novel triplet regularizer. The triplet regularizer balances the mixup effects in both
+feature and label spaces while bounding the linearity in-between samples for pre-training tasks.
+
+
+
+## Requirements
+The code has been tested under the following environments:
+
+* python >= 3.7
+* numpy >= 1.21
+* paddlepaddle >= 2.2 (with suitable CUDA and cuDNN versions)
+* visualdl
+
+
+
+## Model Training
+
+### Step 1. Download dataset files
+We conduct experiments on three popular object recognition datasets: CUB-200-2011, Stanford Cars and
+FGVC-Aircraft. You can download them from the official links below.
+ - [CUB-200-2011](http://www.vision.caltech.edu/datasets/cub_200_2011/)
+ - [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
+ - [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
+
+Please organize each dataset in the following format.
+```
+dataset
+├── train
+│   ├── class_001
+│   │   ├── 1.jpg
+│   │   ├── 2.jpg
+│   │   └── ...
+│   ├── class_002
+│   │   ├── 1.jpg
+│   │   ├── 2.jpg
+│   │   └── ...
+│   └── ...
+└── test
+    ├── class_001
+    │   ├── 1.jpg
+    │   ├── 2.jpg
+    │   └── ...
+    ├── class_002
+    │   ├── 1.jpg
+    │   ├── 2.jpg
+    │   └── ...
+    └── ...
+```
+
+### Step 2. Finetune
+
+You can use the following command to finetune on the target dataset with the SMILE algorithm. Log files and checkpoints are saved in ./output during training. Only the model with the highest accuracy on the validation set is kept during finetuning.
+```
+python finetune.py --name {name of your experiment} --train_dir {path of train dir} --eval_dir {path of eval dir} --model_arch resnet50 --gpu {gpu id} --regularizer smile
+```
+
+### Step 3. Test
+
+You can then load the finetuned checkpoint with the following command and evaluate it on the test set.
+```
+python test.py --test_dir {path of test dir} --model_arch resnet50 --gpu {gpu id} --ckpts {path of finetuning ckpts}
+```
+
+## Results
+
+|Dataset/Method | L2 | SMILE |
+|---|---|---|
+|CUB-200-2011 | 80.79 | 82.38 |
+|Stanford-Cars| 90.72 | 91.74 |
+|FGVC-Aircraft| 86.93 | 89.00 |
+
+
+
+## Citation
+If you use any source code included in this project in your work, please cite the following paper:
+
+```
+@article{Li2021SMILESM,
+  title={SMILE: Self-Distilled MIxup for Efficient Transfer LEarning},
+  author={Xingjian Li and Haoyi Xiong and Chengzhong Xu and Dejing Dou},
+  journal={ArXiv},
+  year={2021},
+  volume={abs/2103.13941}
+}
+```
+
+## Copyright and License
+Copyright 2019 Baidu.com, Inc. 
All Rights Reserved Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/CV/SMILE/backbones/__init__.py b/CV/SMILE/backbones/__init__.py new file mode 100644 index 00000000..cc306ca0 --- /dev/null +++ b/CV/SMILE/backbones/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from .resnet import ResNet # noqa: F401 +from .resnet import resnet18 # noqa: F401 +from .resnet import resnet34 # noqa: F401 +from .resnet import resnet50 # noqa: F401 +from .resnet import resnet101 # noqa: F401 +from .resnet import resnet152 # noqa: F401 +from .mobilenetv2 import MobileNetV2 # noqa: F401 +from .mobilenetv2 import mobilenet_v2 # noqa: F401 +from .vit import VisionTransformer +from .vit import build_vit +'''from .mobilenetv1 import MobileNetV1 # noqa: F401 +from .mobilenetv1 import mobilenet_v1 # noqa: F401 +from .mobilenetv2 import MobileNetV2 # noqa: F401 +from .mobilenetv2 import mobilenet_v2 # noqa: F401 +from .vgg import VGG # noqa: F401 +from .vgg import vgg11 # noqa: F401 +from .vgg import vgg13 # noqa: F401 +from .vgg import vgg16 # noqa: F401 +from .vgg import vgg19 # noqa: F401 +from .lenet import LeNet # noqa: F401''' + +__all__ = [ #noqa + 'ResNet', + 'resnet18', + 'resnet34', + 'resnet50', + 'resnet101', + 'resnet152', + 'VGG', + 'vgg11', + 'vgg13', + 'vgg16', + 'vgg19', + 'MobileNetV1', + 'mobilenet_v1', + 'MobileNetV2', + 'mobilenet_v2', + 'LeNet', + 'ViT' +] diff --git a/CV/SMILE/backbones/mobilenetv2.py b/CV/SMILE/backbones/mobilenetv2.py new file mode 100644 index 00000000..0af6a66c --- /dev/null +++ b/CV/SMILE/backbones/mobilenetv2.py @@ -0,0 +1,228 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
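+"""MobileNetV2 backbone used by SMILE.
+
+Adapted from paddle.vision.models; the main change is that
+``MobileNetV2.forward`` returns ``(logits, feature_map)`` instead of logits
+only, so the pre-pooling feature map is available to SMILE's feature-space
+mixup regularizer (see ``feature_interpolation`` in ``finetune.py``).
+
+A minimal usage sketch (the 102-class setting and batch size below are only
+illustrative, not taken from this repository):
+
+.. code-block:: python
+
+    import paddle
+    from backbones import mobilenet_v2
+
+    model = mobilenet_v2(pretrained=False, num_classes=102)
+    x = paddle.randn([2, 3, 224, 224])
+    logits, feature_map = model(x)
+    # logits: [2, 102]; feature_map: [2, 1280, 7, 7] (before global pooling)
+"""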
+ +import numpy as np +import paddle + +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.utils.download import get_weights_path_from_url + +__all__ = [] + +model_urls = { + 'mobilenetv2_1.0': + ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams', + '0340af0a901346c8d46f4529882fb63d') +} + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, + in_planes, + out_planes, + kernel_size=3, + stride=1, + groups=1, + norm_layer=nn.BatchNorm2D): + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2D( + in_planes, + out_planes, + kernel_size, + stride, + padding, + groups=groups, + bias_attr=False), + norm_layer(out_planes), + nn.ReLU6()) + + +class InvertedResidual(nn.Layer): + def __init__(self, + inp, + oup, + stride, + expand_ratio, + norm_layer=nn.BatchNorm2D): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append( + ConvBNReLU( + inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) + layers.extend([ + ConvBNReLU( + hidden_dim, + hidden_dim, + stride=stride, + groups=hidden_dim, + norm_layer=norm_layer), + nn.Conv2D( + hidden_dim, oup, 1, 1, 0, bias_attr=False), + norm_layer(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Layer): + def __init__(self, scale=1.0, num_classes=1000, with_pool=True): + """MobileNetV2 model from + `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + + Args: + scale (float): scale of channels in each layer. Default: 1.0. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. + + Examples: + .. 
code-block:: python + + from paddle.vision.models import MobileNetV2 + + model = MobileNetV2() + """ + super(MobileNetV2, self).__init__() + self.num_classes = num_classes + self.with_pool = with_pool + input_channel = 32 + last_channel = 1280 + + block = InvertedResidual + round_nearest = 8 + norm_layer = nn.BatchNorm2D + inverted_residual_setting = [ + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + input_channel = _make_divisible(input_channel * scale, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, scale), + round_nearest) + features = [ + ConvBNReLU( + 3, input_channel, stride=2, norm_layer=norm_layer) + ] + + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * scale, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append( + block( + input_channel, + output_channel, + stride, + expand_ratio=t, + norm_layer=norm_layer)) + input_channel = output_channel + + features.append( + ConvBNReLU( + input_channel, + self.last_channel, + kernel_size=1, + norm_layer=norm_layer)) + + self.features = nn.Sequential(*features) + + if with_pool: + self.pool2d_avg = nn.AdaptiveAvgPool2D(1) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes)) + def forward(self, x): + fea = self.features(x) + + if self.with_pool: + x = self.pool2d_avg(fea) + else: + x = fea + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.classifier(x) + return x, fea + + +def _mobilenet(arch, pretrained=False, **kwargs): + model = MobileNetV2(**kwargs) + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) + + param = paddle.load(weight_path) + model.load_dict(param) + + return model + + +def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): + """MobileNetV2 + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. + scale: (float): scale of channels in each layer. Default: 1.0. + + Examples: + .. code-block:: python + + from paddle.vision.models import mobilenet_v2 + + # build model + model = mobilenet_v2() + + # build model and load imagenet pretrained weight + # model = mobilenet_v2(pretrained=True) + + # build mobilenet v2 with scale=0.5 + model = mobilenet_v2(scale=0.5) + """ + model = _mobilenet( + 'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs) + return model diff --git a/CV/SMILE/backbones/resnet.py b/CV/SMILE/backbones/resnet.py new file mode 100644 index 00000000..854767c8 --- /dev/null +++ b/CV/SMILE/backbones/resnet.py @@ -0,0 +1,376 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
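+"""ResNet backbones used by SMILE.
+
+Adapted from paddle.vision.models with two changes that SMILE relies on:
+``forward`` returns the last-stage feature map alongside the logits, and the
+optional ``aux_head`` argument adds a 1000-way auxiliary classifier
+(``aux_fc``). In ``finetune.py`` the auxiliary head is initialised from the
+ImageNet-pretrained ``fc`` weights and trained against the soft predictions
+of the frozen source model.
+
+A minimal usage sketch (the 200-class setting and batch size below are only
+illustrative):
+
+.. code-block:: python
+
+    import paddle
+    from backbones import resnet50
+
+    model = resnet50(pretrained=False, num_classes=200, aux_head=True)
+    x = paddle.randn([2, 3, 224, 224])
+    logits, feature_map, aux_logits = model(x)
+    # logits: [2, 200]; feature_map: [2, 2048, 7, 7]; aux_logits: [2, 1000]
+"""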
+ +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + +from paddle.utils.download import get_weights_path_from_url + +__all__ = [] + +model_urls = { + 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', + 'cf548f46534aa3560945be4b95cd11c4'), + 'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', + '8d2275cf8706028345f78ac0e1d31969'), + #'resnet50': ('https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams', + # 'bd3377719169052e331747d7e139bee2'), + 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', + 'ca6f485ee1ab0492d38f323885b0ad80'), + 'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', + '02f35f034ca3858e1e54d4036443c92d'), + 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', + '7ad16a2f1e7333859ff986138630fd7a'), +} + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + + self.conv1 = nn.Conv2D( + inplanes, planes, 3, padding=1, stride=stride, bias_attr=False) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class BottleneckBlock(nn.Layer): + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) + self.bn1 = norm_layer(width) + + self.conv2 = nn.Conv2D( + width, + width, + 3, + padding=dilation, + stride=stride, + groups=groups, + dilation=dilation, + bias_attr=False) + self.bn2 = norm_layer(width) + + self.conv3 = nn.Conv2D( + width, planes * self.expansion, 1, bias_attr=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + """ResNet model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + Block (BasicBlock|BottleneckBlock): block module of model. + depth (int): layers of resnet, default: 50. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. + + Examples: + .. 
code-block:: python + + from paddle.vision.models import ResNet + from paddle.vision.models.resnet import BottleneckBlock, BasicBlock + + resnet50 = ResNet(BottleneckBlock, 50) + + resnet18 = ResNet(BasicBlock, 18) + + """ + + def __init__(self, block, depth, num_classes=1000, with_pool=True, aux_head=False): + super(ResNet, self).__init__() + layer_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3] + } + layers = layer_cfg[depth] + self.num_classes = num_classes + self.with_pool = with_pool + self.aux_head = aux_head + self._norm_layer = nn.BatchNorm2D + + self.inplanes = 64 + self.dilation = 1 + + self.conv1 = nn.Conv2D( + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + + if num_classes > 0: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + if self.aux_head: + self.aux_fc = nn.Linear(512 * block.expansion, 1000) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + 1, + stride=stride, + bias_attr=False), + norm_layer(planes * block.expansion), ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, 1, 64, + previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + fea = self.layer4(x) + + if self.with_pool: + x = self.avgpool(fea) + + if self.num_classes > 0: + x0 = paddle.flatten(x, 1) + x = self.fc(x0) + if self.aux_head: + aux_logits = self.aux_fc(x0) + if self.aux_head: + return x, fea, aux_logits + else: + return x, fea + + +def _resnet(arch, Block, depth, pretrained, **kwargs): + model = ResNet(Block, depth, **kwargs) + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(arch) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) + + param = paddle.load(weight_path) + model.set_dict(param) + + return model + + +def resnet18(pretrained=False, **kwargs): + """ResNet 18-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. 
code-block:: python + + from paddle.vision.models import resnet18 + + # build model + model = resnet18() + + # build model and load imagenet pretrained weight + # model = resnet18(pretrained=True) + """ + return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) + + +def resnet34(pretrained=False, **kwargs): + """ResNet 34-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet34 + + # build model + model = resnet34() + + # build model and load imagenet pretrained weight + # model = resnet34(pretrained=True) + """ + return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) + + +def resnet50(pretrained=False, **kwargs): + """ResNet 50-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet50 + + # build model + model = resnet50() + + # build model and load imagenet pretrained weight + # model = resnet50(pretrained=True) + """ + return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnet101(pretrained=False, **kwargs): + """ResNet 101-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet101 + + # build model + model = resnet101() + + # build model and load imagenet pretrained weight + # model = resnet101(pretrained=True) + """ + return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) + + +def resnet152(pretrained=False, **kwargs): + """ResNet 152-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet152 + + # build model + model = resnet152() + + # build model and load imagenet pretrained weight + # model = resnet152(pretrained=True) + """ + return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) diff --git a/CV/SMILE/backbones/vit.py b/CV/SMILE/backbones/vit.py new file mode 100644 index 00000000..8bcaa726 --- /dev/null +++ b/CV/SMILE/backbones/vit.py @@ -0,0 +1,426 @@ +# Copyright (c) 2021 PPViT Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +ViT in Paddle +A Paddle Implementation of Vision Transformer (ViT) as described in: +"An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale" + - Paper Link: https://arxiv.org/abs/2010.11929 +""" +import paddle +import paddle.nn as nn + + +class Identity(nn.Layer): + """ Identity layer + The output of this layer is the input without any change. + This layer is used to avoid using 'if' condition in methods such as forward + """ + def forward(self, x): + return x + + +class PatchEmbedding(nn.Layer): + """Patch Embedding + Apply patch embedding (which is implemented using Conv2D) on input data. 
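+    For example, with the defaults (image_size=224, patch_size=16, embed_dim=768),
+    a [B, 3, 224, 224] input is mapped to a [B, 196, 768] patch sequence (196 = 14 * 14).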
+ Attributes: + image_size: image size + patch_size: patch size + num_patches: num of patches + patch_embddings: patch embed operation (Conv2D) + """ + def __init__(self, + image_size=224, + patch_size=16, + in_channels=3, + embed_dim=768): + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size) * (image_size // patch_size) + self.patch_embedding = nn.Conv2D(in_channels=in_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=patch_size) + def forward(self, x): + x = self.patch_embedding(x) + x = x.flatten(2) # [B, C, H, W] -> [B, C, h*w] + x = x.transpose([0, 2, 1]) # [B, C, h*w] -> [B, h*w, C] = [B, N, C] + return x + + +class Attention(nn.Layer): + """ Attention module + Attention module for ViT, here q, k, v are assumed the same. + The qkv mappings are stored as one single param. + Attributes: + num_heads: number of heads + attn_head_size: feature dim of single head + all_head_size: feature dim of all heads + qkv: a nn.Linear for q, k, v mapping + scales: 1 / sqrt(single_head_feature_dim) + out: projection of multi-head attention + attn_dropout: dropout for attention + proj_dropout: final dropout before output + softmax: softmax op for attention + """ + def __init__(self, + embed_dim, + num_heads, + attn_head_size=None, + qkv_bias=True, + dropout=0., + attention_dropout=0.): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + if attn_head_size is not None: + self.attn_head_size = attn_head_size + else: + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.attn_head_size = embed_dim // num_heads + self.all_head_size = self.attn_head_size * num_heads + + w_attr_1, b_attr_1 = self._init_weights() + self.qkv = nn.Linear(embed_dim, + self.all_head_size * 3, # weights for q, k, and v + weight_attr=w_attr_1, + bias_attr=b_attr_1 if qkv_bias else False) + + self.scales = self.attn_head_size ** -0.5 + + w_attr_2, b_attr_2 = self._init_weights() + self.out = nn.Linear(self.all_head_size, + embed_dim, + weight_attr=w_attr_2, + bias_attr=b_attr_2) + + self.attn_dropout = nn.Dropout(attention_dropout) + self.proj_dropout = nn.Dropout(dropout) + self.softmax = nn.Softmax(axis=-1) + + def _init_weights(self): + weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=.02)) + bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)) + return weight_attr, bias_attr + + def transpose_multihead(self, x): + """[B, N, C] -> [B, N, n_heads, head_dim] -> [B, n_heads, N, head_dim]""" + new_shape = x.shape[:-1] + [self.num_heads, self.attn_head_size] + x = x.reshape(new_shape) # [B, N, C] -> [B, N, n_heads, head_dim] + x = x.transpose([0, 2, 1, 3]) # [B, N, n_heads, head_dim] -> [B, n_heads, N, head_dim] + return x + + def forward(self, x): + qkv = self.qkv(x).chunk(3, axis=-1) + q, k, v = map(self.transpose_multihead, qkv) + + q = q * self.scales + attn = paddle.matmul(q, k, transpose_y=True) # [B, n_heads, N, N] + attn = self.softmax(attn) + attn = self.attn_dropout(attn) + + z = paddle.matmul(attn, v) # [B, n_heads, N, head_dim] + z = z.transpose([0, 2, 1, 3]) # [B, N, n_heads, head_dim] + new_shape = z.shape[:-2] + [self.all_head_size] + z = z.reshape(new_shape) # [B, N, all_head_size] + + z = self.out(z) + z = self.proj_dropout(z) + return z + + +class Mlp(nn.Layer): + """ MLP module + Impl using nn.Linear and activation is GELU, dropout is applied. 
+ Ops: fc -> act -> dropout -> fc -> dropout + Attributes: + fc1: nn.Linear + fc2: nn.Linear + act: GELU + dropout: dropout after fc + """ + + def __init__(self, + embed_dim, + mlp_ratio, + dropout=0.): + super().__init__() + w_attr_1, b_attr_1 = self._init_weights() + self.fc1 = nn.Linear(embed_dim, + int(embed_dim * mlp_ratio), + weight_attr=w_attr_1, + bias_attr=b_attr_1) + + w_attr_2, b_attr_2 = self._init_weights() + self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), + embed_dim, + weight_attr=w_attr_2, + bias_attr=b_attr_2) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def _init_weights(self): + weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.TruncatedNormal(std=0.2)) + bias_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)) + return weight_attr, bias_attr + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class TransformerLayer(nn.Layer): + """Transformer Layer + Transformer layer contains attention, norm, mlp and residual + Attributes: + embed_dim: transformer feature dim + attn_norm: nn.LayerNorm before attention + mlp_norm: nn.LayerNorm before mlp + mlp: mlp modual + attn: attention modual + """ + def __init__(self, + embed_dim, + num_heads, + attn_head_size=None, + qkv_bias=True, + mlp_ratio=4., + dropout=0., + attention_dropout=0., + droppath=0.): + super().__init__() + w_attr_1, b_attr_1 = self._init_weights() + self.attn_norm = nn.LayerNorm(embed_dim, + weight_attr=w_attr_1, + bias_attr=b_attr_1, + epsilon=1e-6) + + self.attn = Attention(embed_dim, + num_heads, + attn_head_size, + qkv_bias, + dropout, + attention_dropout) + + #self.drop_path = DropPath(droppath) if droppath > 0. else Identity() + + w_attr_2, b_attr_2 = self._init_weights() + self.mlp_norm = nn.LayerNorm(embed_dim, + weight_attr=w_attr_2, + bias_attr=b_attr_2, + epsilon=1e-6) + + self.mlp = Mlp(embed_dim, mlp_ratio, dropout) + + def _init_weights(self): + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0)) + bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)) + return weight_attr, bias_attr + + def forward(self, x): + h = x + x = self.attn_norm(x) + x = self.attn(x) + #x = self.drop_path(x) + x = x + h + + h = x + x = self.mlp_norm(x) + x = self.mlp(x) + #x = self.drop_path(x) + x = x + h + + return x + + +class Encoder(nn.Layer): + """Transformer encoder + Encoder encoder contains a list of TransformerLayer, and a LayerNorm. 
+ Attributes: + layers: nn.LayerList contains multiple EncoderLayers + encoder_norm: nn.LayerNorm which is applied after last encoder layer + """ + def __init__(self, + embed_dim, + num_heads, + depth, + attn_head_size=None, + qkv_bias=True, + mlp_ratio=4.0, + dropout=0., + attention_dropout=0., + droppath=0.): + super().__init__() + # stochatic depth decay + depth_decay = [x.item() for x in paddle.linspace(0, droppath, depth)] + + layer_list = [] + for i in range(depth): + layer_list.append(TransformerLayer(embed_dim, + num_heads, + attn_head_size, + qkv_bias, + mlp_ratio, + dropout, + attention_dropout, + depth_decay[i])) + self.layers = nn.LayerList(layer_list) + + w_attr_1, b_attr_1 = self._init_weights() + self.encoder_norm = nn.LayerNorm(embed_dim, + weight_attr=w_attr_1, + bias_attr=b_attr_1, + epsilon=1e-6) + + def _init_weights(self): + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0)) + bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)) + return weight_attr, bias_attr + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = self.encoder_norm(x) + return x + + +class VisionTransformer(nn.Layer): + """ViT transformer + ViT Transformer, classifier is a single Linear layer for finetune, + For training from scratch, two layer mlp should be used. + Classification is done using cls_token. + Args: + image_size: int, input image size, default: 224 + patch_size: int, patch size, default: 16 + in_channels: int, input image channels, default: 3 + num_classes: int, number of classes for classification, default: 1000 + embed_dim: int, embedding dimension (patch embed out dim), default: 768 + depth: int, number ot transformer blocks, default: 12 + num_heads: int, number of attention heads, default: 12 + attn_head_size: int, dim of head, if none, set to embed_dim // num_heads, default: None + mlp_ratio: float, ratio of mlp hidden dim to embed dim(mlp in dim), default: 4.0 + qkv_bias: bool, If True, enable qkv(nn.Linear) layer with bias, default: True + dropout: float, dropout rate for linear layers, default: 0. + attention_dropout: float, dropout rate for attention layers default: 0. + droppath: float, droppath rate for droppath layers, default: 0. 
+ representation_size: int, set representation layer (pre-logits) if set, default: None + """ + def __init__(self, + image_size=224, + patch_size=16, + in_channels=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + attn_head_size=None, + mlp_ratio=4, + qkv_bias=True, + dropout=0., + attention_dropout=0., + droppath=0., + representation_size=None): + super().__init__() + # create patch embedding + self.patch_embedding = PatchEmbedding(image_size, + patch_size, + in_channels, + embed_dim) + # create posision embedding + self.position_embedding = paddle.create_parameter( + shape=[1, 1 + self.patch_embedding.num_patches, embed_dim], + dtype='float32', + default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02)) + # create cls token + self.cls_token = paddle.create_parameter( + shape=[1, 1, embed_dim], + dtype='float32', + default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02)) + self.pos_dropout = nn.Dropout(dropout) + # create multi head self-attention layers + self.encoder = Encoder(embed_dim, + num_heads, + depth, + attn_head_size, + qkv_bias, + mlp_ratio, + dropout, + attention_dropout, + droppath) + # pre-logits + if representation_size is not None: + self.num_features = representation_size + w_attr_1, b_attr_1 = self._init_weights() + self.pre_logits = nn.Sequential( + nn.Linear(embed_dim, + representation_size, + weight_attr=w_attr_1, + bias_attr=b_attr_1), + nn.ReLU()) + else: + self.pre_logits = Identity() + + # classifier head + w_attr_2, b_attr_2 = self._init_weights() + self.classifier = nn.Linear(embed_dim, + num_classes, + weight_attr=w_attr_2, + bias_attr=b_attr_2) + + def _init_weights(self): + weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(1.0)) + bias_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)) + return weight_attr, bias_attr + + def forward_features(self, x): + x = self.patch_embedding(x) + cls_tokens = self.cls_token.expand((x.shape[0], -1, -1)) + x = paddle.concat((cls_tokens, x), axis=1) + x = x + self.position_embedding + x = self.pos_dropout(x) + x = self.encoder(x) + x = self.pre_logits(x[:, 0]) # cls_token only + return x + + def forward(self, x): + x = self.forward_features(x) + logits = self.classifier(x) + return logits + + +def build_vit(config): + """build vit model from config""" + model = VisionTransformer(image_size=config.DATA.IMAGE_SIZE, + patch_size=config.MODEL.PATCH_SIZE, + in_channels=config.DATA.IMAGE_CHANNELS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.EMBED_DIM, + depth=config.MODEL.DEPTH, + num_heads=config.MODEL.NUM_HEADS, + attn_head_size=config.MODEL.ATTN_HEAD_SIZE, + mlp_ratio=config.MODEL.MLP_RATIO, + qkv_bias=config.MODEL.QKV_BIAS, + dropout=config.MODEL.DROPOUT, + attention_dropout=config.MODEL.ATTENTION_DROPOUT, + droppath=config.MODEL.DROPPATH, + representation_size=None) + return model \ No newline at end of file diff --git a/CV/SMILE/config.py b/CV/SMILE/config.py new file mode 100644 index 00000000..81d46bc0 --- /dev/null +++ b/CV/SMILE/config.py @@ -0,0 +1,152 @@ +# Copyright (c) 2021 PPViT Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration +Configurations for (1) data processing, (2) model archtecture, and (3) training settings, etc. +Config can be set by .yaml file or by argparser +""" +import os +from yacs.config import CfgNode as CN +import yaml + +_C = CN() +_C.BASE = [''] + +# data settings +_C.DATA = CN() +_C.DATA.BATCH_SIZE = 256 # train batch_size on single GPU +_C.DATA.BATCH_SIZE_EVAL = None # (disabled in update_config) val batch_size on single GPU +_C.DATA.DATA_PATH = '/dataset/imagenet/' # path to dataset +_C.DATA.DATASET = 'imagenet2012' # dataset name, currently only support imagenet2012 +_C.DATA.IMAGE_SIZE = 224 # input image size e.g., 224 +_C.DATA.IMAGE_CHANNELS = 3 # input image channels: e.g., 3 +_C.DATA.CROP_PCT = 0.875 # input image scale ratio, scale is applied before centercrop in eval mode +_C.DATA.NUM_WORKERS = 1 # number of data loading threads +_C.DATA.IMAGENET_MEAN = [0.5, 0.5, 0.5] #[0.485, 0.456, 0.406] # imagenet mean values +_C.DATA.IMAGENET_STD = [0.5, 0.5, 0.5] #[0.229, 0.224, 0.225] # imagenet std values + +# model general settings +_C.MODEL = CN() +_C.MODEL.TYPE = 'vit' +_C.MODEL.NAME = 'vit' +_C.MODEL.RESUME = None # full model path for resume training +_C.MODEL.PRETRAINED = None # full model path for finetuning +_C.MODEL.NUM_CLASSES = 1000 # num of classes for classifier +_C.MODEL.DROPOUT = 0.0 +_C.MODEL.ATTENTION_DROPOUT = 0.0 +_C.MODEL.DROPPATH = 0.0 +# model transformer settings +_C.MODEL.PATCH_SIZE = 16 +_C.MODEL.EMBED_DIM = 768 +_C.MODEL.NUM_HEADS = 12 +_C.MODEL.ATTN_HEAD_SIZE = None # if None, use embed_dim // num_heads as head dim +_C.MODEL.DEPTH = 12 +_C.MODEL.MLP_RATIO = 4.0 +_C.MODEL.QKV_BIAS = True + +# training settings (for ViT-B/16 pretrain) +_C.TRAIN = CN() +_C.TRAIN.LAST_EPOCH = 0 +_C.TRAIN.NUM_EPOCHS = 300 +_C.TRAIN.WARMUP_EPOCHS = 32 +_C.TRAIN.WEIGHT_DECAY = 0.3 +_C.TRAIN.BASE_LR = 3e-3 +_C.TRAIN.WARMUP_START_LR = 1e-6 +_C.TRAIN.END_LR = 0.0 +_C.TRAIN.GRAD_CLIP = None +_C.TRAIN.ACCUM_ITER = 1 + +# optimizer +_C.TRAIN.OPTIMIZER = CN() +_C.TRAIN.OPTIMIZER.NAME = 'AdamW' +_C.TRAIN.OPTIMIZER.EPS = 1e-8 +_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) + + +# misc +_C.SAVE = "./output" # output folder, saves logs and weights +_C.SAVE_FREQ = 10 # freq to save chpt +_C.REPORT_FREQ = 20 # freq to logging info +_C.VALIDATE_FREQ = 1 # freq to do validation +_C.SEED = 0 # random seed +_C.EVAL = False # run evaluation only +_C.AMP = False # auto mix precision training + + +def _update_config_from_file(config, cfg_file): + """Load cfg file (.yaml) and update config object + Args: + config: config object + cfg_file: config file (.yaml) + Return: + None + """ + config.defrost() + with open(cfg_file, 'r') as infile: + yaml_cfg = yaml.load(infile, Loader=yaml.FullLoader) + for cfg in yaml_cfg.setdefault('BASE', ['']): + if cfg: + _update_config_from_file( + config, os.path.join(os.path.dirname(cfg_file), cfg) + ) + config.merge_from_file(cfg_file) + config.freeze() + + +def update_config(config, args): + """Update config by ArgumentParser + Configs that are often used can be updated from arguments + Args: + args: ArgumentParser contains options + Return: + config: 
updated config + """ + if args.cfg: + _update_config_from_file(config, args.cfg) + config.defrost() + if args.dataset: + config.DATA.DATASET = args.dataset + if args.batch_size: + config.DATA.BATCH_SIZE = args.batch_size + config.DATA.BATCH_SIZE_EVAL = args.batch_size + if args.batch_size_eval: + config.DATA.BATCH_SIZE_EVAL = args.batch_size_eval + if args.image_size: + config.DATA.IMAGE_SIZE = args.image_size + if args.accum_iter: + config.TRAIN.ACCUM_ITER = args.accum_iter + if args.data_path: + config.DATA.DATA_PATH = args.data_path + if args.output: + config.SAVE = args.output + if args.eval: + config.EVAL = True + if args.pretrained: + config.MODEL.PRETRAINED = args.pretrained + if args.resume: + config.MODEL.RESUME = args.resume + if args.last_epoch: + config.TRAIN.LAST_EPOCH = args.last_epoch + if args.amp: # only for training + config.AMP = not config.EVAL + config.freeze() + return config + + +def get_config(cfg_file=None): + """Return a clone of config and optionally overwrite it from yaml file""" + config = _C.clone() + if cfg_file: + _update_config_from_file(config, cfg_file) + return config \ No newline at end of file diff --git a/CV/SMILE/finetune.py b/CV/SMILE/finetune.py new file mode 100644 index 00000000..50d142d7 --- /dev/null +++ b/CV/SMILE/finetune.py @@ -0,0 +1,385 @@ +import time +import logging +import os +import sys +import argparse +import random +from visualdl import LogWriter +import numpy as np + +import paddle +import paddle.nn.functional as F +import paddle.nn as nn +from paddle.vision import transforms +from paddle.vision.datasets import DatasetFolder + +from backbones import mobilenet_v2, resnet18, resnet34, resnet50, resnet101, resnet152 + + + +def get_args(): + parser = argparse.ArgumentParser(description='PaddlePaddle Deep Transfer Learning Toolkit, Image Classification Fine-tuning Example') + parser.add_argument('--name', type = str, default = 'flower102') + parser.add_argument('--train_dir', default='../CoTuning/data/finetune/flower102/train') + parser.add_argument('--eval_dir', default='../CoTuning/data/finetune/flower102/test') + parser.add_argument('--log_dir', default = './visual_log') + parser.add_argument('--save', type = str, default = './output') + parser.add_argument('--ema_decay', type = float, default = 0.999) + parser.add_argument('--model_arch', default='resnet50') + parser.add_argument('--image_size', type = int, default = 224) + parser.add_argument('--batch_size', type=int, default=48) + parser.add_argument('--batch_size_eval', type=int, default=8) + parser.add_argument('--max_iters', type=int, default=9000) + parser.add_argument('--lr', type=float, default=0.01) + parser.add_argument('--wd', type=float, default=1e-4) + parser.add_argument('--alpha', type = float, default = 0.2, help = 'coefficient of mixup') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--print_frequency', type=int, default=50) + parser.add_argument('--eval_frequency', type=int, default=500) + parser.add_argument('--seed', type=int, default=2022) + parser.add_argument('--reg_lambda', type=float, default=0.01) + parser.add_argument('--aux_lambda', type=float, default=0.1) + parser.add_argument('--cls_lambda', type=float, default=0.0001) + parser.add_argument('--regularizer', type = str, default = 'smile') + + args = parser.parse_args() + return args + + +def get_dataloader_train(args): + train_path = args.train_dir + transform_train = transforms.Compose([ + transforms.Resize(size=(256, 256)), + transforms.RandomHorizontalFlip(), + 
transforms.RandomCrop(args.image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + train_set = DatasetFolder(train_path, transform=transform_train) + train_loader = paddle.io.DataLoader(train_set, shuffle=True, batch_size=args.batch_size) + num_classes = len(train_set.classes) + + return train_loader, num_classes + + +def get_dataloader_val(args): + val_path = args.eval_dir + transform_val = transforms.Compose([ + transforms.Resize(size=(256, 256)), + transforms.CenterCrop(args.image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + val_set = DatasetFolder(val_path, transform=transform_val) + val_loader = paddle.io.DataLoader(val_set, shuffle=False, batch_size=args.batch_size_eval) + + return val_loader + + + +def get_logger(filename, logger_name=None): + """set logging file and format + Args: + filename: str, full path of the logger file to write + logger_name: str, the logger name, e.g., 'master_logger', 'local_logger' + Return: + logger: python logger + """ + log_format = "%(asctime)s %(message)s" + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt="%m%d %I:%M:%S %p") + # different name is needed when creating multiple logger in one process + logger = logging.getLogger(logger_name) + fh = logging.FileHandler(os.path.join(filename)) + fh.setFormatter(logging.Formatter(log_format)) + logger.addHandler(fh) + return logger + + + +def mixup_data(x, y, index=None, alpha=0.2): + if alpha > 0.: + lam = np.random.beta(alpha, alpha) + else: + lam = 1. + if lam < 0.5: + lam = 1 - lam + batch_size = x.shape[0] + if index is None: + index = paddle.randperm(batch_size).numpy() + mixed_x = lam * x + (1 - lam) * x[index] + y_a, y_b = y, y[index] + return mixed_x, y_a, y_b, lam + +def mixup_criterion_hard(criterion, pred, y_a, y_b, lam): + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + +logsoftmax = paddle.nn.LogSoftmax(axis=1) +def mixup_criterion_soft(pred, y_a, y_b, lam): + log_probs = logsoftmax(pred) + loss_a = (-y_a * log_probs).mean(0).sum() + loss_b = (-y_b * log_probs).mean(0).sum() + loss = lam * loss_a + (1 - lam) * loss_b + return loss + +def feature_interpolation(fm_src, fm_tgt, lam, index_perb): + fm_src = fm_src.detach() + b, c, h, w = fm_src.shape + fm_src = lam * fm_src + (1 - lam) * fm_src[index_perb] + fea_loss = paddle.norm(fm_src - fm_tgt) / (h * w) + return fea_loss + + +def reg_fc(model): + l2_cls = 0 + for name, param in model.named_parameters(): + if name.startswith('fc.') or name.startswith('aux_fc.'): + l2_cls += 0.5 * paddle.norm(param) ** 2 + return l2_cls + +def update_mean_teacher(ema_decay, model_source, model_tgt): # debug + alpha = ema_decay + #alpha = min(1 - 1 / (args.max_iters + 1), args.ema_decay) + new_dict = {} + for name, src_param in model_source.named_parameters(): + if name.startswith('fc.'): + new_dict[name] = src_param + continue + tgt_param = model_tgt.state_dict()[name] + src_param = alpha * src_param + (1 - alpha) * tgt_param + new_dict[name] = src_param + # src_param.data.mul_(alpha).add_(1 - alpha, tgt_param.data) + # model_source.state_dict()[name].set_dict(alpha * src_param + (1 - alpha) * tgt_param) + model_source.set_dict(new_dict) + + + + + +def train(iter_tgt, + model_source, + model_tgt, + reg_lambda, + aux_lambda, + cls_lambda, + alpha, + criterion, + ema_decay, + optimizer, + cur_iter, + total_iter, + debug_steps=100, + logger=None, + cur_regularizer='smile'): + + + 
model_tgt.train() + time_st = time.time() + + data = iter_tgt.next() + image = data[0] + label = paddle.unsqueeze(data[1], 1) + + if cur_regularizer == 'smile': + # mix up + index_perm = paddle.randperm(image.shape[0]).numpy() + inputs_mix, targets_a, targets_b, lam = mixup_data(image, label, index=index_perm, alpha=alpha) + logits_mix, features_mix, outputs_aux = model_tgt(inputs_mix) + loss_main = mixup_criterion_hard(criterion, logits_mix, targets_a, targets_b, lam) + loss_all = {'loss_main': loss_main} + + logits_src, feature_scr = model_source(image) + outputs_src = F.softmax(logits_src, axis=1) + loss_aux = mixup_criterion_soft(outputs_aux, outputs_src, outputs_src[index_perm], lam) + loss_all['loss_aux'] = aux_lambda * loss_aux + + loss_reg = feature_interpolation(feature_scr, features_mix, lam, index_perm) + loss_all['loss_reg'] = reg_lambda * loss_reg + + if ema_decay < 1-1e-6 and cur_iter % 10 == 0: + update_mean_teacher(ema_decay, model_source, model_tgt) + elif cur_regularizer == 'l2': + logits, _, _ = model_tgt(image) + loss_main = criterion(logits, label) + loss_all = {'loss_main': loss_main} + loss_classifier = reg_fc(model_tgt) + loss_all['loss_classifier'] = cls_lambda * loss_classifier + + loss = sum(loss_all.values()) + + + + loss.backward() + optimizer.step() + optimizer.clear_grad() + + model_tgt.eval() + with paddle.no_grad(): + logits, _, _ = model_tgt(image) + model_tgt.train() + acc = paddle.metric.accuracy(logits, label) + train_time = time.time() - time_st + if logger and cur_iter % debug_steps == 0: + logger.info( + f"Step[{cur_iter:04d}/{total_iter:04d}], " + + f"Loss is: {loss.numpy()}, " + + f"Loss all: {loss_all}" + + f"Train ACC@1: {acc.numpy()}"+ + f"Train Time: {train_time}") + return loss.numpy(), acc.numpy() + + + +def validate(dataloader, model_tgt, criterion, total_batch, debug_steps=100, logger=None): + """Validation for whole dataset + Args: + dataloader: paddle.io.DataLoader, dataloader instance + model: nn.Layer, a ViT model + criterion: nn.criterion + total_batch: int, total num of batches for one epoch + debug_steps: int, num of iters to log info, default: 100 + logger: logger for logging, default: None + Returns: + val_loss_meter.avg: float, average loss on current process/gpu + val_acc1_meter.avg: float, average top1 accuracy on current process/gpu + val_acc5_meter.avg: float, average top5 accuracy on current process/gpu + val_time: float, valitaion time + """ + model_tgt.eval() + losses = [] + accuracies = [] + time_st = time.time() + + with paddle.no_grad(): + for batch_id, data in enumerate(dataloader): + image = data[0] + label = paddle.unsqueeze(data[1], 1) + logits, _, _= model_tgt(image) + + loss = criterion(logits, label) + acc = paddle.metric.accuracy(logits, label) + accuracies.append(acc.numpy()) + losses.append(loss.numpy()) + + avg_acc, avg_loss = np.mean(accuracies), np.mean(losses) + + if logger and batch_id % debug_steps == 0 and batch_id != 0: + logger.info( + f"Val Step[{batch_id:04d}/{total_batch:04d}], " + + f"Avg Loss: {avg_loss}, " + + f"Avg Acc@1: {avg_acc}, ") + + val_time = time.time() - time_st + return avg_loss, avg_acc, val_time + + +def finetune_cnn(args): + # STEP 0: Preparation + + last_epoch = -1 + paddle.device.set_device(f'gpu:{args.gpu}') + seed = args.seed + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + if not os.path.exists(args.save): + os.makedirs(args.save, exist_ok=True) + logger = get_logger(filename=os.path.join(args.save, f'{args.name}_{args.regularizer}.txt')) + logdir = 
os.path.join(args.log_dir, args.regularizer) + if not os.path.exists(logdir): + os.makedirs(logdir, exist_ok=True) + writer = LogWriter(logdir = logdir) + logger.info(f'\n{args}') + + # STEP 1: Create train and val dataloader + dataloader_train, num_classes = get_dataloader_train(args) + if os.path.exists(args.eval_dir): + dataloader_val = get_dataloader_val(args) + + # STEP 2: load model + model_source = eval(args.model_arch)(pretrained=True, num_classes = 1000) # imagenet pretrained + model_tgt = eval(args.model_arch)(pretrained=True, num_classes = num_classes, aux_head=True) # setting aux classifier (aux head) + model_tgt.aux_fc.set_dict(model_source.fc.state_dict()) # auc_fc use pretrained ckpt! + logger.info('finish load the pretrained model') + + # STEP 3: freeze model_src + # algo = determine_algo(model, args, dataloader_train) + model_source.eval() + for param in model_source.parameters(): + param.stop_gradient = True + + # STEP 4: Define optimizer and lr_scheduler + criterion = paddle.nn.CrossEntropyLoss() + params = model_tgt.parameters() + lr_scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[int(2.0*(args.max_iters+1000)/3.0)],values=[args.lr,args.lr*0.1]) + optimizer = paddle.optimizer.Momentum(learning_rate=lr_scheduler, parameters=params,momentum=0.9, use_nesterov=True, weight_decay = args.wd) + + # STEP 5: Run training + logger.info(f"Start training from iter 0.") + len_tgt = len(dataloader_train) + iter_tgt = iter(dataloader_train) + best_val_acc = 0.0 + for cur_iter in range(0, args.max_iters): + # train + cur_regularizer = args.regularizer + if args.regularizer == 'smile' and cur_iter >= args.max_iters - 1000: + cur_regularizer = 'l2' + if cur_iter % 500 == 0: + logger.info(f"Now training iter {cur_iter}. LR={optimizer.get_lr():.6f}") + if (cur_iter + 1) % len_tgt == 0: + iter_tgt = iter(dataloader_train) + train_loss, train_acc = train(iter_tgt=iter_tgt, + model_source=model_source, + model_tgt=model_tgt, + reg_lambda=args.reg_lambda, + aux_lambda=args.aux_lambda, + cls_lambda=args.cls_lambda, + alpha=args.alpha, + criterion=criterion, + ema_decay=args.ema_decay, + optimizer=optimizer, + cur_iter=cur_iter, + total_iter=args.max_iters, + debug_steps=args.print_frequency, + logger=logger, + cur_regularizer=cur_regularizer) + lr_scheduler.step() + writer.add_scalar(tag="train_acc", step=cur_iter, value=train_acc) + writer.add_scalar(tag="train_loss", step=cur_iter, value=train_loss) + + # validation and save ckpts + if (cur_iter % args.eval_frequency == 0 and cur_iter != 0) or (cur_iter+1) == args.max_iters: + logger.info(f'----- Validation after iter: {cur_iter}') + val_loss, val_acc, val_time = validate( + dataloader=dataloader_val, + model_tgt=model_tgt, + criterion=criterion, + total_batch=len(dataloader_val), + debug_steps=args.print_frequency, + logger=logger) + logger.info(f"----- Iter[{cur_iter:03d}/{args.max_iters:03d}], " + + f"Validation Loss: {val_loss:.4f}, " + + f"Validation Acc@1: {val_acc:.4f}, " + + f"time: {val_time:.2f}") + writer.add_scalar(tag="val_acc", step=cur_iter, value=val_acc) + writer.add_scalar(tag="val_loss", step=cur_iter, value=val_loss) + + # save if necessary + + if val_acc > best_val_acc: + best_val_acc = val_acc + model_path = os.path.join(args.save, f"{args.name}_{args.regularizer}_Best.pdparams") + state_dict = dict() + state_dict['model'] = model_tgt.state_dict() + state_dict['optimizer'] = optimizer.state_dict() + state_dict['iter'] = cur_iter + if lr_scheduler is not None: + state_dict['lr_scheduler'] = 
lr_scheduler.state_dict() + paddle.save(state_dict, model_path) + logger.info(f"----- Save model: {model_path}") + print('Current best acc on val set is: ', best_val_acc) + + +if __name__ == '__main__': + print(paddle.__version__) + args = get_args() + finetune_cnn(args) + diff --git a/CV/SMILE/image/method.png b/CV/SMILE/image/method.png new file mode 100644 index 00000000..85734de8 Binary files /dev/null and b/CV/SMILE/image/method.png differ diff --git a/CV/SMILE/setup.py b/CV/SMILE/setup.py new file mode 100644 index 00000000..fd936e1d --- /dev/null +++ b/CV/SMILE/setup.py @@ -0,0 +1,27 @@ +from setuptools import setup, find_packages +import paddletransfer + +with open("README.md", "r", encoding = "utf-8") as fh: + long_description = fh.read() + +setup( + name = "paddletransfer", + version = paddletransfer.__version__, + author = "Baidu-BDL", + author_email = "autodl@baidu.com", + description = "transfer learning toolkits for finetune deep learning models", + long_description = long_description, + long_description_content_type = "text/markdown", + classifiers = [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Programming Language :: Python :: 3.9' + ], + packages = find_packages(), + python_requires=">=3.7", + install_requires=[ + 'numpy' + ], + license = 'Apache 2.0', + keywords = "transfer learning toolkits for paddle models" +) diff --git a/CV/SMILE/test.py b/CV/SMILE/test.py new file mode 100644 index 00000000..0dd052e7 --- /dev/null +++ b/CV/SMILE/test.py @@ -0,0 +1,150 @@ +import time +import logging +import os +import sys +import argparse +import random +from visualdl import LogWriter +import numpy as np + +import paddle +import paddle.nn.functional as F +import paddle.nn as nn +from paddle.vision import transforms +from paddle.vision.datasets import DatasetFolder + +from backbones import mobilenet_v2, resnet18, resnet34, resnet50, resnet101, resnet152 + + + +def get_args(): + parser = argparse.ArgumentParser(description='PaddlePaddle Deep Transfer Learning Toolkit, Image Classification Fine-tuning Example') + parser.add_argument('--test_dir', default='../CoTuning/data/finetune/flower102/test') + parser.add_argument('--model_arch', default='resnet50') + parser.add_argument('--ckpts', type = str) + parser.add_argument('--image_size', type = int, default = 224) + parser.add_argument('--batch_size_eval', type=int, default=8) + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--seed', type=int, default=2022) + args = parser.parse_args() + return args + + + + +def get_dataloader_test(args): + test_path = args.test_dir + transform_val = transforms.Compose([ + transforms.Resize(size=(256, 256)), + transforms.CenterCrop(args.image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_set = DatasetFolder(test_path, transform=transform_val) + test_loader = paddle.io.DataLoader(test_set, shuffle=False, batch_size=args.batch_size_eval) + num_classes = len(test_set.classes) + + return test_loader, num_classes + + + +def get_logger(filename, logger_name=None): + """set logging file and format + Args: + filename: str, full path of the logger file to write + logger_name: str, the logger name, e.g., 'master_logger', 'local_logger' + Return: + logger: python logger + """ + log_format = "%(asctime)s %(message)s" + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt="%m%d %I:%M:%S %p") + # different name is needed when creating multiple logger in one 
process + logger = logging.getLogger(logger_name) + fh = logging.FileHandler(os.path.join(filename)) + fh.setFormatter(logging.Formatter(log_format)) + logger.addHandler(fh) + return logger + + + + +def test(dataloader, model_tgt, criterion, total_batch, debug_steps=100): + """Test for whole dataset + Args: + dataloader: paddle.io.DataLoader, dataloader instance + model: nn.Layer, a ViT model + criterion: nn.criterion + total_batch: int, total num of batches for one epoch + debug_steps: int, num of iters to log info, default: 100 + logger: logger for logging, default: None + Returns: + test_loss_meter.avg: float, average loss on current process/gpu + test_acc1_meter.avg: float, average top1 accuracy on current process/gpu + test_time: float, test time + """ + model_tgt.eval() + losses = [] + accuracies = [] + time_st = time.time() + + with paddle.no_grad(): + for batch_id, data in enumerate(dataloader): + image = data[0] + label = paddle.unsqueeze(data[1], 1) + logits, _, _= model_tgt(image) + + loss = criterion(logits, label) + acc = paddle.metric.accuracy(logits, label) + accuracies.append(acc.numpy()) + losses.append(loss.numpy()) + + avg_acc, avg_loss = np.mean(accuracies), np.mean(losses) + + if batch_id % debug_steps == 0 and batch_id != 0: + print( + f"Val Step[{batch_id:04d}/{total_batch:04d}], " + + f"Avg Loss: {avg_loss}, " + + f"Avg Acc@1: {avg_acc}, ") + + val_time = time.time() - time_st + return avg_loss, avg_acc, val_time + + +def test_cnn(args): + # STEP 0: Preparation + paddle.device.set_device(f'gpu:{args.gpu}') + seed = args.seed + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + + # STEP 1: Create test dataloader + dataloader_test, num_classes = get_dataloader_test(args) + + # STEP 2: Load model + model_tgt = eval(args.model_arch)(pretrained=False, num_classes = num_classes, aux_head=True) # setting aux classifier (aux head) + loaded_dict = paddle.load(args.ckpts) + model_tgt.set_dict(loaded_dict['model']) # auc_fc use pretrained ckpt! + print('finish load the finetuning model') + + # STEP 5: Testing + criterion = paddle.nn.CrossEntropyLoss() + print("Start testing...") + + val_loss, val_acc, val_time = test( + dataloader=dataloader_test, + model_tgt=model_tgt, + criterion=criterion, + total_batch=len(dataloader_test), + debug_steps=50) + print(f"Validation Loss: {val_loss:.4f}, " + + f"Validation Acc@1: {val_acc:.4f}, " + + f"time: {val_time:.2f}") + + + +if __name__ == '__main__': + print(paddle.__version__) + args = get_args() + test_cnn(args) + diff --git a/README.md b/README.md index 6c15c1fd..b2b936ee 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ | 图像分类 | [CLPI](CV/CLPI-Collaborative-Learning-for-Diabetic-Retinopathy-Grading) | 模型利用一个Lesion Generator改善了糖尿病视网膜病变图像分级的模型性能,理论上可用于所有希望实现局部+整体模型分析的场景 | - | | 图像分类 | [RSNA-IHD](CV/Effective Transformer-based Solution for RSNA Intracranial Hemorrhage Detection) | 提出了一种有效的颅内出血检测(IHD)方法,其性能超过了在RSNA-IHD竞赛(2019)中获胜的解决方案。与此同时,与获胜者的解决方案相比,我们的模型只有其20%的参数量和10%的FLOPs | https://arxiv.org/abs/2205.07556 | | 小样本学习 | [PaddleFSL](CV/PaddleFSL) | 小样本学习工具包,可复现多个常用基线方法在多个图片分类数据集上的汇报效果 | - | +| 迁移学习 | [SMILE](CV/SMILE) | 提出了一种自蒸馏样本混合迁移学习框架,适用于小样本图片分类 | https://arxiv.org/abs/2103.13941 | ## 自然语言处理 | 任务类型 | 目录 | 简介 | 论文链接 |