Deep Learning Tutorials Translated to JAX with Flax

We have translated our Deep Learning Tutorials from PyTorch to JAX with Flax

Figure 1: We have recently translated our Deep Learning Tutorials to JAX with Flax, offering 1-to-1 translations between PyTorch (Lightning) and JAX with Flax.

PyTorch is one of the most popular deep learning frameworks used in machine learning research. However, another framework, JAX, has recently gained more and more popularity. But why should you learn JAX if there are already so many other deep learning frameworks like PyTorch and TensorFlow? The short answer: because it can be extremely fast. For instance, a small GoogleNet on CIFAR10, which we discuss in detail in Tutorial 5, can be trained in JAX 3x faster than in PyTorch with a similar setup. Note that for larger models, larger batch sizes, or smaller GPUs, a considerably smaller speedup is expected, and the code has not been designed for benchmarking. Nonetheless, JAX enables this speedup by compiling functions and numerical programs for accelerators (GPU/TPU) just in time, finding the optimal utilization of the hardware.

Frameworks with dynamic computation graphs like PyTorch cannot achieve the same efficiency, since they cannot anticipate the next operations before the user calls them. For example, in an Inception block of GoogleNet, we apply multiple convolutional layers in parallel on the same input. JAX can optimize the execution of this block by compiling the whole forward pass for the available accelerator and parallelizing the convolutions where possible. In contrast, when calling the first convolutional layer in PyTorch, the framework does not know that multiple convolutions on the same feature map will follow. It sends each operation one by one to the GPU and can only adapt the execution after seeing the next Python calls. Hence, JAX can make more efficient use of the GPU than, for instance, PyTorch.
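To make this concrete, here is a minimal, hypothetical sketch (not taken from the tutorials themselves) of an Inception-style block in Flax. Wrapping the forward pass in `jax.jit` lets XLA see all parallel convolutions at once and compile them together for the available accelerator:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MiniInceptionBlock(nn.Module):
    """Toy Inception-style block: several convolutions applied to the same input."""

    @nn.compact
    def __call__(self, x):
        # Three parallel branches operating on the same feature map.
        b1 = nn.Conv(features=16, kernel_size=(1, 1))(x)
        b2 = nn.Conv(features=16, kernel_size=(3, 3))(x)
        b3 = nn.Conv(features=16, kernel_size=(5, 5))(x)
        # Concatenate the branch outputs along the channel axis (NHWC layout).
        return jnp.concatenate([b1, b2, b3], axis=-1)

model = MiniInceptionBlock()
x = jnp.ones((8, 32, 32, 3))                   # dummy batch of CIFAR-sized images
params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters

# jit compiles the whole forward pass, so XLA can fuse and
# parallelize the three convolutions where possible.
@jax.jit
def forward(params, x):
    return model.apply(params, x)

out = forward(params, x)  # first call compiles, subsequent calls reuse the compiled program
print(out.shape)          # (8, 32, 32, 48)
```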

Because of that, we have recently translated our popular Deep Learning Tutorials of the DL course at the University of Amsterdam from PyTorch with PyTorch Lightning to JAX with Flax. These 1-to-1 translations allow you to see implementations of common models side by side, experience how to go from PyTorch to JAX, and follow model creation in a more guided way. Furthermore, we also provide a simple introduction to JAX with Flax, building a very small neural network with basic JAX tools. Check them out on our RTD website!
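As a flavor of what such an introduction covers, below is a minimal sketch of a small Flax network with a jitted training step. The model size, optimizer (optax here), and hyperparameters are illustrative placeholders rather than the tutorial's actual code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class SimpleMLP(nn.Module):
    """A tiny classifier, roughly in the spirit of the introductory tutorial."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)
        return x

model = SimpleMLP()
rng = jax.random.PRNGKey(42)
dummy_x = jnp.ones((4, 784))          # e.g. flattened 28x28 inputs
params = model.init(rng, dummy_x)     # initialize parameters with a dummy batch

optimizer = optax.sgd(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    # Cross-entropy loss on integer class labels.
    def loss_fn(p):
        logits = model.apply(p, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```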

Phillip Lippe
PhD student in Artificial Intelligence, working on temporal causality and causal representation learning.
