In this post, I share my collection of notebooks demonstrating how to build a CIFAR-10 classifier in various deep learning frameworks. The objectives are to show how to:
- create a simple classifier using CNNs
- track experiments
- use hyperparameter tuning frameworks
I started working on this post to update my old notebooks on training a classifier on GPU vs TPU. In 2018, when I first wrote about GPU vs TPU, I wanted to find out how much work was involved in converting the code to switch from one accelerator (GPU) to another (TPU), and whether it provided any benefit out of the box. A lot has changed since then.
TensorFlow 2.0 has made things a lot simpler. Eager mode is gentle on my brain, and the Keras API, as always, is fun to work with. The tf.data API makes constructing input pipelines easy, and features such as AUTOTUNE, cache, and prefetch take care of optimizing the pipeline. tf.distribute.Strategy makes it simple to switch between accelerators (GPU, TPU).
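To give a flavour of how these pieces fit together, here is a minimal sketch (not the exact code from the notebooks): a tf.data pipeline using AUTOTUNE, cache, and prefetch, and a tf.distribute.Strategy picked at runtime so the same code runs on GPU or TPU. The tiny CNN and the batch size are just placeholders.

```python
import tensorflow as tf

# Use a TPU if one is attached; otherwise fall back to the default
# (GPU/CPU) strategy. Only the strategy object changes -- the input
# pipeline and model code below stay the same on either accelerator.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
except ValueError:
    strategy = tf.distribute.get_strategy()

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 128 * strategy.num_replicas_in_sync  # placeholder batch size

(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y),
         num_parallel_calls=AUTOTUNE)   # parallelize preprocessing
    .cache()                            # keep decoded data in memory
    .shuffle(10_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(AUTOTUNE)                 # overlap input prep with training
)

with strategy.scope():                  # variables created under the strategy
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu",
                               input_shape=(32, 32, 3)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

model.fit(train_ds, epochs=5)
```

The nice part is that switching accelerators only changes which strategy object gets created; everything inside `strategy.scope()` and the tf.data pipeline is untouched.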
This time around I decided to cover PyTorch, PyTorch Lightning, and JAX as well. While I do have some experience working with PyTorch and Lightning, JAX is mainly there because I wanted a reason to make something in JAX 😀.
Each card gives you some information about the notebook: training time, train and test accuracy, etc. I would advise you not to pay too much attention to the accuracy metrics, because the augmentation pipeline differs slightly across some notebooks. Also, it is not my intention to compare the frameworks. They all work great, each with its own pros and cons.
Update: Nov 3rd, 2020
My primary workstation, the one with the GTX 1080TI in the cards below, is dead. I cannot continue with the following planned notebooks for now:
- Optuna: PyTorch & PyTorch Lightning on GTX 1080TI
- Ray Tune: PyTorch Lightning on GTX 1080TI
- JAX * on GTX 1080TI