Ternarized Neural Network for Image Classification

This repository contains a Pytorch implementation of the paper "Optimize Deep Convolutional Neural Network with Ternarized Weights and High Accuracy".

If you find this project useful to you, please cite our work:

@article{he2018optimize,
  title={Optimize Deep Convolutional Neural Network with Ternarized Weights and High Accuracy},
  author={He, Zhezhi and Gong, Boqing and Fan, Deliang},
  journal={IEEE Winter Conference on Applications of Computer Vision (WACV)},
  year={2019}
}


Dependencies:

  • Python 3.6 (Anaconda)
  • PyTorch 0.4.1

Usage

To train a new model or evaluate a pretrained model, use the following commands in a terminal. Remember to edit the bash scripts with the correct dataset/model paths.

CIFAR-10:

bash train_CIFAR10.sh

ImageNet:

bash train_ImageNet.sh

For the bash script to run correctly, edit train_ImageNet.sh and set the PYTHON environment, imagenet_path (ImageNet dataset path), and pretrained_model (trained model path). Use --evaluate to report validation accuracy.

#!/usr/bin/env sh

PYTHON=/home/elliot/anaconda3/envs/pytorch_041/bin/python
imagenet_path=
pretrained_model=

############ directory to save result #############
DATE=`date +%Y-%m-%d`

if [ ! -d "./save/${DATE}" ]; then
    mkdir -p ./save/${DATE}/
fi

############ Configurations ###############
model=resnet18b_fq_lq_tern_tex_4
dataset=imagenet
epochs=50
batch_size=256
optimizer=Adam
# add more labels as additional info into the saving path
label_info=test

$PYTHON main.py --dataset ${dataset} \
    --data_path ${imagenet_path} --arch ${model} \
    --save_path ./save/${DATE}/${dataset}_${model}_${epochs}_${label_info} \
    --epochs ${epochs} --learning_rate 0.0001 --optimizer ${optimizer} \
    --schedule 30 40 45 --gammas 0.2 0.2 0.5 \
    --batch_size ${batch_size} --workers 8 --ngpu 2 \
    --print_freq 100 --decay 0.000005 \
    --resume ${pretrained_model} --evaluate \
    --model_only --fine_tune

Results

Trained models can be downloaded with the links provided (Google Drive).

ResNet-20/32/44/56 on CIFAR-10:

The entire network is ternarized (including the first and last layers) for ResNet-20/32/44/56 on CIFAR-10. Note that all the CIFAR-10 experiments are trained directly from scratch; no pretrained model is used. Users can also ternarize a model starting from a pretrained one. Since CIFAR-10 is a toy dataset, I did not upload the trained models. A sketch of threshold-based weight ternarization is given after the table below.

                 ResNet-20   ResNet-32   ResNet-44   ResNet-56
Full-Precision   91.70%      92.36%      92.47%      92.68%
Ternarized       91.65%      92.48%      92.71%      92.86%
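
For reference, below is a minimal sketch of threshold-based weight ternarization with a per-layer scaling factor. The 0.7 * mean|w| threshold is the Ternary Weight Networks heuristic and the ternarize helper is named here only for illustration; the exact threshold and scaling rule used in this repository may differ.

import torch

def ternarize(w, delta_ratio=0.7):
    # Per-layer threshold: a fraction of the mean absolute weight
    # (0.7 * mean|w| is the Ternary Weight Networks heuristic).
    delta = delta_ratio * w.abs().mean()
    mask = (w.abs() > delta).float()
    sign = w.sign() * mask                              # values in {-1, 0, +1}
    # Scaling factor: mean magnitude of the weights above the threshold.
    alpha = (w.abs() * mask).sum() / mask.sum().clamp(min=1.0)
    return alpha * sign

w = torch.randn(64, 3, 3, 3)      # e.g. one conv layer's weights
w_tern = ternarize(w)
print(torch.unique(w_tern))       # only {-alpha, 0, +alpha} remain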

AlexNet on ImageNet:

                           First and Last Layer   Top-1/Top-5 Accuracy
AlexNet (Full-Precision)   Full-Precision         61.78% / 82.87%
AlexNet (Ternarized)       Full-Precision         58.59% / 80.44%
AlexNet (Ternarized)       Ternarized             57.21% / 79.41%

ResNet-18/34/50 on ImageNet:

The pretrained models of the full-precision baselines are from PyTorch; accuracies are reported as Top-1/Top-5.

                 ResNet-18         ResNet-34         ResNet-50
Full-Precision   69.75% / 89.07%   73.31% / 91.42%   76.13% / 92.86%
Ternarized       66.01% / 86.78%   70.95% / 89.89%   74.00% / 91.77%

ResNet-18 on ImageNet with Residual Expansion Layer (REL): To reduce the accuracy drop caused by the aggressive model compression, we append residual expansion layers to compensate for the accuracy gap. Note that the ternarized ResNet-18 reported above corresponds to t_ex=1 (i.e., without REL). A rough sketch of the expansion idea is given after the table below.

ResNet-18   First and Last Layer   Top-1/Top-5 Accuracy
t_ex=2      Ternarized             68.35% / 88.20%
t_ex=4      Ternarized             69.44% / 88.91%
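
As a rough illustration of how an expanded ternary layer could be assembled: one plausible realization is greedy residual quantization, where each additional ternary term approximates the error left by the previous terms and the layer output is the sum of the per-term convolutions. This sketch reuses the illustrative ternarize helper from above and is not necessarily the exact scheme implemented in this repository.

import torch
import torch.nn.functional as F

def residual_expansion(w, t_ex=2, delta_ratio=0.7):
    # Greedy residual quantization: each ternary term approximates
    # whatever error the previous terms left behind.
    terms, residual = [], w.clone()
    for _ in range(t_ex):
        w_t = ternarize(residual, delta_ratio)   # ternarize() from the sketch above
        terms.append(w_t)
        residual = residual - w_t
    return terms

def expanded_conv(x, w, t_ex=2, stride=1, padding=1):
    # The expanded layer sums t_ex convolutions, each with ternary weights.
    return sum(F.conv2d(x, w_t, stride=stride, padding=padding)
               for w_t in residual_expansion(w, t_ex))

x = torch.randn(1, 3, 32, 32)
w = torch.randn(16, 3, 3, 3)
y = expanded_conv(x, w, t_ex=4)
print(y.shape)                                   # torch.Size([1, 16, 32, 32])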

Task list

  • Upload Trained models for CIFAR-10 and ImageNet datasets.

  • Encoding the weights of residual expansion layers to further reduce the model size (i.e., memory usage).

  • Optimizing the thresholds chosen for the residual expansion layers.
