This repo is a modified version of https://github.com/erikjandevries/mnist-learning-docker.
TLDR: If you don't have nvidia-docker2 installed, install it first (see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html for details). Then follow these steps:
- install python3 and pip, then install the required packages:
pip install imageio
pip install pillow
- clone the repo to your local machine
- cd <project_root>/prepare, run
python download-mnist-data.py
- cd <project_root>/basic, run
./build.sh
- run
time ./cpu_run.sh
to train the model on the CPU and get the total training time
- run
time ./gpu_run.sh
to train the model on the GPU and get the total training time
- to check the GPU power draw and temperature during training, run
watch -n 2 nvidia-smi
in another terminal
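If you would rather log the power and temperature readings than watch them, the following is a minimal Python sketch. It assumes nvidia-smi is on the PATH and relies on its --query-gpu CSV output. The helper names parse_gpu_stats and read_gpu_stats are illustrative and are not part of this repo.

```python
import subprocess

# nvidia-smi query that prints one "power, temperature" line per GPU.
QUERY = [
    "nvidia-smi",
    "--query-gpu=power.draw,temperature.gpu",
    "--format=csv,noheader,nounits",
]


def _to_float(field):
    """Convert a CSV field to float; return None for values such as '[N/A]'."""
    try:
        return float(field)
    except ValueError:
        return None


def parse_gpu_stats(csv_text):
    """Parse nvidia-smi CSV output into one dict per GPU."""
    stats = []
    for line in csv_text.strip().splitlines():
        power, temp = [field.strip() for field in line.split(",")]
        stats.append({"power_w": _to_float(power), "temp_c": _to_float(temp)})
    return stats


def read_gpu_stats():
    """Run nvidia-smi once and return the parsed readings (needs an Nvidia driver)."""
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_gpu_stats(result.stdout)
```

Calling read_gpu_stats() in a loop with time.sleep(2) between calls would give roughly the same sampling interval as the watch command above.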
Learning MNIST using TensorFlow in a Docker container
You will need the MNIST dataset in order to run these Docker images.
You can download it from http://yann.lecun.com/exdb/mnist/ using the provided scripts, which are based on https://github.com/datapythonista/mnist.
Running the script
download-mnist-data.py --out-dir=<out_dir>
downloads the training and test sets into the required folder/file structure:
- <out_dir>/<train|test>/<label>/<image_index>.png
By default, <out_dir> is "/mnt/data/Data/mnist".
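To check that the download produced the layout described above, a small sketch like the following can count the PNG files per label. The function name count_images is hypothetical and is not part of the provided scripts. It only assumes the <out_dir>/<train|test>/<label>/<image_index>.png structure.

```python
from collections import Counter
from pathlib import Path


def count_images(out_dir, split):
    """Count PNG files per label under <out_dir>/<split>/<label>/."""
    counts = Counter()
    for png in Path(out_dir, split).glob("*/*.png"):
        # The parent directory name is the label, e.g. "7".
        counts[png.parent.name] += 1
    return dict(counts)
```

For example, count_images("/mnt/data/Data/mnist", "train") returns a dict mapping each label to its image count. The standard MNIST dataset has 60,000 training and 10,000 test images, so the totals should match those figures if the download completed.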
In this image, you will define and train a TensorFlow model using Keras. After preparing the data, you will need to run the following three scripts.
To build the Docker image, run the build.sh script.
To train the model, run the Docker image with the train.sh script.
Training time depends on your hardware. With an Nvidia GeForce GTX 1070 it takes about 25 seconds per epoch. By default the script is set up to run for 3 epochs.
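As a rough guide, the figures above imply about 75 seconds of pure training time with the defaults. The sketch below only does that arithmetic. The function name is illustrative, and the estimate ignores container start-up and data loading.

```python
def estimate_training_seconds(epochs=3, seconds_per_epoch=25.0):
    """Rough training time: number of epochs times seconds per epoch."""
    return epochs * seconds_per_epoch


print(estimate_training_seconds())  # 75.0
```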
To test the predictions of the model, run the Docker image with the test_predictions.sh script.
This will load the pretrained model from disk and test the predictions on 10 batches of images. For each failed prediction, the image will be displayed, and the true and predicted labels will be printed to the console.
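The reporting step can be sketched in plain Python as follows. This is not the code used by test_predictions.sh. It is a minimal illustration that assumes the true and predicted labels are already available as two equal-length lists, and the function name failed_predictions is hypothetical.

```python
def failed_predictions(true_labels, predicted_labels):
    """Return (index, true, predicted) for every mismatched prediction."""
    return [
        (i, t, p)
        for i, (t, p) in enumerate(zip(true_labels, predicted_labels))
        if t != p
    ]


# Example with made-up labels: only the third image is misclassified.
for index, true, predicted in failed_predictions([7, 2, 1, 0], [7, 2, 7, 0]):
    print(f"image {index}: true={true} predicted={predicted}")
# prints: image 2: true=1 predicted=7
```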