Build a CIFAR-10 Classifier Using Various Frameworks

2 min read (Updated: December 20, 2020)

In this post, I share my collection of notebooks demonstrating how to build a CIFAR-10 classifier in various deep learning frameworks. The objectives are to show how to:

  • create a simple classifier using CNNs
  • track experiments
  • use hyperparameter tuning frameworks
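For the first objective, here is a minimal sketch of the kind of small CNN classifier the notebooks build. The layer sizes and architecture below are illustrative only, not the exact model from any of the notebooks.

```python
import tensorflow as tf

def build_model(num_classes: int = 10) -> tf.keras.Model:
    # A small CNN for 32x32x3 CIFAR-10 images: two conv/pool stages
    # followed by a dense head. Sizes are illustrative.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

From here, training is a single `model.fit(train_ds, epochs=50)` call; the notebooks differ mainly in the input pipeline, tracking, and tuning wrapped around this core.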

I started working on this post to update my old notebooks on training a classifier on a GPU vs. a TPU. When I first wrote about GPU vs. TPU in 2018, I wanted to find out how much work was involved in converting the code to switch from one accelerator (GPU) to the other (TPU), and whether the switch provided any benefit out of the box. A lot has changed since then.

TensorFlow 2.0 has made things a lot simpler. Eager mode is gentle on my brain, and the Keras API, as always, is fun to work with. The tf.data API makes constructing input pipelines easy, and features such as AUTOTUNE, cache, and prefetch take care of optimizing them. Finally, tf.distribute.Strategy makes it simple to switch between accelerators (GPU, TPU).
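The pipeline features mentioned above look roughly like this. This is a sketch, not code from the notebooks; the normalization map stands in for whatever preprocessing and augmentation a given notebook uses.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # let tf.data tune parallelism/buffers at runtime

def make_pipeline(images, labels, batch_size: int = 128) -> tf.data.Dataset:
    # shuffle -> map -> cache -> batch -> prefetch
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(buffer_size=len(images))
    ds = ds.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y),
                num_parallel_calls=AUTOTUNE)
    ds = ds.cache()             # keep preprocessed examples in memory
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTOTUNE)  # overlap input prep with training steps
    return ds

# Switching accelerators is mostly a matter of picking a strategy:
# on a TPU you would use tf.distribute.TPUStrategy(resolver) instead.
strategy = tf.distribute.get_strategy()  # default strategy (CPU/GPU)
with strategy.scope():
    pass  # build and compile the model here
```

The same pipeline then feeds `model.fit` unchanged whichever strategy is active, which is what makes the GPU/TPU switch so cheap compared to 2018.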

This time around, I decided to cover PyTorch, PyTorch Lightning, and JAX as well. While I have some experience working with PyTorch and Lightning, JAX is here mainly because I wanted an excuse to build something in it.

Each entry in the collection below gives you some information about the notebook: training time, train and test accuracy, and so on. I would advise you not to pay too much attention to the accuracy metrics, because the augmentation pipelines differ slightly between some notebooks. Also, it is not my intention to compare the frameworks; they all work great, and each has its own pros and cons.

Update: Nov 3rd, 2020

My primary workstation, the one with the GTX 1080TI in the entries below, is dead, so I cannot continue with the following planned notebooks for now:

  • Optuna: PyTorch & PyTorch Lightning on GTX 1080TI
  • Ray Tune: PyTorch Lightning on GTX 1080TI
  • JAX on GTX 1080TI

Notebooks Collection

| Notebook | Date | Framework | Accelerator | Train Acc | Test Acc | Epochs | Time | Tracking | Notes |
|---|---|---|---|---|---|---|---|---|---|
| TensorFlow | 13 Aug 2020 | TensorFlow | GTX 1080TI (GPU) | 73.95% | 67.83% | 50 | 4m 38s | Comet | pretrained, basic |
| TensorFlow + TPU Strategy | 20 Sep 2020 | TensorFlow | TPU v2 (8 cores) | 78.29% | 68.49% | 50 | 4m 14s | W&B | pretrained |
| Keras Tuner Basic | 12 Sep 2020 | TensorFlow | GTX 1080TI (GPU) | 0.00% | 57.70% | 50 | 1h | --- | Keras Tuner |
| Keras Tuner V2 | 19 Sep 2020 | TensorFlow | GTX 1080TI (GPU) | 94.09% | 68.67% | 50 | 1h | --- | Keras Tuner |
| Ray Tune | 03 Oct 2020 | TensorFlow | GTX 1080TI (GPU) | 100.00% | 61.70% | 50 | 1h | --- | Ray Tune |
| Optuna | 19 Oct 2020 | TensorFlow | GTX 1080TI (GPU) | 88.77% | 79.19% | 50 | 16m 8s | --- | Optuna |
| PyTorch | 19 Aug 2020 | PyTorch | GTX 1080TI (GPU) | 73.95% | 67.83% | 50 | 4m 38s | Comet | pretrained |
| PyTorch XLA, 1 core | 29 Sep 2020 | PyTorch | TPU v2 (8 cores) | 74.02% | 69.58% | 50 | 21m 21s | W&B | pretrained |
| PyTorch XLA, all cores | --- | PyTorch | TPU v2 (8 cores) | --- | --- | --- | --- | --- | pretrained |
| Ray Tune | 25 Oct 2020 | PyTorch | GTX 1080TI (GPU) | 15.91% | 71.20% | 50 | 87m 35s | TensorBoard | Ray Tune |
| GPU Notebook | 03 Sep 2020 | PyTorch Lightning | GTX 1080TI (GPU) | 85.15% | 51.99% | 50 | 8m 9s | Comet | pretrained |
| TPU Notebook | 29 Sep 2020 | PyTorch Lightning | TPU v2 (8 cores) | 92.30% | 71.20% | 50 | 4m 40s | Comet | pretrained |
| Ray Tune (pending) | --- | PyTorch Lightning | GTX 1080TI (GPU) | --- | --- | --- | --- | --- | Ray Tune |
| JAX (coming soon πŸ˜…) | --- | JAX | GTX 1080TI (GPU) | --- | --- | --- | --- | --- | --- |

Conclusion

This collection demonstrates the flexibility of modern deep learning frameworks and how easy it has become to switch between different accelerators and experiment tracking tools. Whether you prefer TensorFlow, PyTorch, or want to explore JAX, there are multiple paths to building effective CIFAR-10 classifiers.