This repository contains a PyTorch implementation of the paper:
SWALP: Stochastic Weight Averaging in Low-Precision Training.
Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa
Low precision operations can provide scalability, memory savings, portability, and energy efficiency. This paper proposes SWALP, an approach to low precision training that averages low-precision SGD iterates with a modified learning rate schedule. SWALP is easy to implement and can match the performance of full-precision SGD even with all numbers quantized down to 8 bits, including the gradient accumulators. Additionally, we show that SWALP converges arbitrarily close to the optimal solution for quadratic objectives, and to a noise ball asymptotically smaller than low precision SGD in strongly convex settings.
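In outline, SWALP runs low-precision SGD with a high constant learning rate and keeps a full-precision running average of the quantized iterates; the averaged weights are the final model. The sketch below is a minimal, self-contained illustration of that loop in plain PyTorch, not the training code in this repo: the toy model, data, schedule constants, and the identity `quantize` stub are placeholders (the actual experiments quantize to 8-bit small-block block floating point, a toy version of which is sketched after the experiment commands below).

```python
import torch
import torch.nn as nn

def quantize(x):
    # Placeholder: the real experiments quantize weights, activations, and
    # gradients to 8-bit small-block block floating point (sketched further down).
    return x

torch.manual_seed(0)
model = nn.Linear(16, 1)                      # toy stand-in for VGG16 / PreResNet164
data, target = torch.randn(256, 16), torch.randn(256, 1)
loss_fn = nn.MSELoss()

lr, swa_start = 0.05, 50                      # toy constants; the paper warms up, then holds a constant high lr
swa_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
n_avg = 0

for epoch in range(100):
    loss = loss_fn(model(data), target)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p.add_(p.grad, alpha=-lr)         # plain SGD step
            p.copy_(quantize(p))              # keep the iterate in low precision
    if epoch >= swa_start:                    # running average of the low-precision iterates
        n_avg += 1
        for k, v in model.state_dict().items():
            swa_state[k] += (v - swa_state[k]) / n_avg

model.load_state_dict(swa_state)              # the averaged weights are the SWALP model
```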
This repo contains the code to replicate our experiments on the CIFAR datasets with VGG16 and PreResNet164.
Please cite our work if you find this approach useful in your research:
@misc{gu2019swalp,
title={SWALP : Stochastic Weight Averaging in Low-Precision Training},
author={Guandao Yang and Tianyi Zhang and Polina Kirichenko and Junwen Bai and Andrew Gordon Wilson and Christopher De Sa},
year={2019},
eprint={1904.11943},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
- CUDA 9.0
- PyTorch version 1.0
- torchvision
- tensorflow (to use tensorboard)
Other requirements can be installed with `pip install -r requirements.txt`.
We provide scripts to run the Small-block Block Floating Point experiments on CIFAR10 and CIFAR100 with VGG16 or PreResNet164. The following commands reproduce the experimental results.
seed=100 # Specify experiment seed.
bash exp/block_vgg_swa.sh CIFAR10 ${seed} # SWALP training on VGG16 with Small-block BFP in CIFAR10
bash exp/block_vgg_swa.sh CIFAR100 ${seed} # SWALP training on VGG16 with Small-block BFP in CIFAR100
bash exp/block_resnet_swa.sh CIFAR10 ${seed} # SWALP training on PreResNet164 with Small-block BFP in CIFAR10
bash exp/block_resnet_swa.sh CIFAR100 ${seed} # SWALP training on PreResNet164 with Small-block BFP in CIFAR100
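The scripts above use small-block block floating point (BFP), in which each small block of values shares one exponent and each value keeps a low-bit signed mantissa. Below is a rough, illustrative PyTorch version of such a quantizer; the block size, word length, and rounding choice are assumptions made for the sketch, not the kernels invoked by the scripts.

```python
import torch

def block_fp_quantize(x, wl=8, block=64):
    """Illustrative small-block block floating point quantizer: every `block`
    values share one exponent (taken from the block's largest magnitude) and
    each value is rounded to a wl-bit signed mantissa."""
    flat = x.reshape(-1)
    pad = (-flat.numel()) % block
    blocks = torch.cat([flat, flat.new_zeros(pad)]).reshape(-1, block)
    # shared exponent per block
    exp = torch.floor(torch.log2(blocks.abs().max(dim=1, keepdim=True)[0].clamp_min(1e-12)))
    scale = 2.0 ** (exp - (wl - 2))
    # round mantissas and clamp to the representable signed range
    q = torch.clamp(torch.round(blocks / scale), -(2 ** (wl - 1)), 2 ** (wl - 1) - 1) * scale
    return q.reshape(-1)[: flat.numel()].reshape(x.shape)

w = torch.randn(128, 64)
print((w - block_fp_quantize(w)).abs().max())   # quantization error is small relative to w's scale
```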
The low-precision results (SGD-LP and SWALP) are produced by running the scripts in the `exp/` folder.
The full-precision results (SGD-FP and SWA-FP) are produced by running the SWA repo.
Test error (%) on CIFAR:

Dataset | Model | SGD-FP | SWA-FP | SGD-LP | SWALP |
---|---|---|---|---|---|
CIFAR10 | VGG16 | 6.81±0.09 | 6.51±0.14 | 7.61±0.15 | 6.70±0.12 |
CIFAR10 | PreResNet164 | 4.63±0.18 | 4.03±0.10 | | |
CIFAR100 | VGG16 | 27.23±0.17 | 25.93±0.21 | 29.59±0.32 | 26.65±0.29 |
CIFAR100 | PreResNet164 | 22.20±0.57 | 19.95±0.19 | | |
Tianyi Zhang provides an alternative implementation built on QPyTorch, a framework for low-precision training.
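For prototyping a similar setup, the sketch below shows one plausible way to wire QPyTorch's building blocks for 8-bit block floating point training; the word length, blocking dimension, rounding modes, and learning rate are illustrative assumptions rather than the settings of that implementation.

```python
import torch
from torch.optim import SGD
from qtorch import BlockFloatingPoint
from qtorch.quant import Quantizer, quantizer
from qtorch.optim import OptimLP

# 8-bit block floating point; word length and blocking dimension are illustrative choices
bfp8 = BlockFloatingPoint(wl=8, dim=0)

# module that quantizes activations on the forward pass and errors on the backward pass
act_error_quant = Quantizer(forward_number=bfp8, backward_number=bfp8,
                            forward_rounding="stochastic", backward_rounding="stochastic")

# quantization functions used by the low-precision optimizer wrapper
weight_quant = quantizer(forward_number=bfp8, forward_rounding="stochastic")
grad_quant = quantizer(forward_number=bfp8, forward_rounding="stochastic")
acc_quant = quantizer(forward_number=bfp8, forward_rounding="stochastic")

model = torch.nn.Sequential(torch.nn.Linear(16, 10), act_error_quant)
optimizer = OptimLP(SGD(model.parameters(), lr=0.05),
                    weight_quant=weight_quant,
                    grad_quant=grad_quant,
                    acc_quant=acc_quant)
```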
We use the SWA repo as a starter template. Network architecture implementations are adapted from:
- VGG: github.com/pytorch/vision/
- PreResNet: github.com/bearpaw/pytorch-classification