Deep learning changes our lives in ways large and small every day. Whether it’s Siri or Alexa following our voice commands, real-time translation apps on our phones, or the computer vision technology enabling smart tractors, warehouse robots, and self-driving cars, every month seems to bring new advances. And almost all of these deep learning applications are written in one of three frameworks: TensorFlow, PyTorch, or JAX.
Which of these deep learning frameworks should you use? In this article, we’ll take a high-level comparative look at TensorFlow, PyTorch, and JAX. We’ll do our best to give you an idea of the types of apps that play to their strengths, as well as consider factors like community support and ease of use.
Should you use TensorFlow?
“No one ever got fired for buying IBM” was the rallying cry of computing in the 1970s and 1980s, and the same could be said about using TensorFlow for deep learning in the 2010s. But as we all know, IBM fell by the wayside as the 1990s dawned. Is TensorFlow still competitive in this new decade, seven years after its initial release in 2015?
Well, certainly. It’s not as if TensorFlow has stood still all this time. TensorFlow 1.x was all about building static graphs in a very un-Pythonic way, but with the TensorFlow 2.x line you can also build models using “eager” mode for immediate evaluation of operations, much as PyTorch does, which makes development feel far more natural. At the high level, TensorFlow gives you Keras for easier development, and at the low level it gives you the XLA (Accelerated Linear Algebra) optimizing compiler for speed. XLA works wonders for boosting performance on GPUs, and it’s the primary method for tapping the power of Google’s TPUs (Tensor Processing Units), which deliver unmatched performance for training models at massive scale.
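To make that concrete, here is a minimal, hedged sketch of what TensorFlow 2.x development typically looks like: eager execution by default, Keras for model definition, and an opt-in XLA flag. The layer sizes are illustrative, and the jit_compile option assumes a reasonably recent 2.x release.

import tensorflow as tf

# TensorFlow 2.x executes operations eagerly by default -- no sessions or static graphs.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(tf.reduce_sum(x))  # evaluated immediately

# Keras provides the high-level model-building API.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),
    tf.keras.layers.Dense(10),
])

# jit_compile=True asks Keras to compile the training step with XLA where supported.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    jit_compile=True,
)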
Then there are all the things that TensorFlow has been doing well for years. Do you need to serve models in a well-defined and repeatable way on a mature platform? TensorFlow Serving is there for you. Do you need to retarget your model deployments for the web, for low-power computing such as smartphones, or for resource-constrained IoT devices? Both TensorFlow.js and TensorFlow Lite are very mature at this point. And of course, since Google still runs 100% of its production deployments using TensorFlow, you can be sure that TensorFlow can handle your scale.
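As a rough illustration of that low-power path, converting a trained Keras model for on-device inference with TensorFlow Lite takes only a few lines; the optimization flag and file name below are just examples, and `model` is assumed to be a trained Keras model like the one sketched above.

import tensorflow as tf

# Convert a trained tf.keras model into the TensorFlow Lite flatbuffer format.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # default size/latency optimizations
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)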
But…well, there’s a certain lack of energy around the project that’s hard to ignore these days. The upgrade from TensorFlow 1.x to TensorFlow 2.x was, in a word, brutal. Some companies looked at the effort required to update their code to work properly on the new major release and decided to port it to PyTorch instead. TensorFlow has also lost steam in the research community, which began to prefer the flexibility of PyTorch a few years ago, leading to a decline in TensorFlow’s use in research papers.
The Keras saga didn’t help either. Keras became an integral part of TensorFlow releases two years ago, but it was recently spun back out into a separate library with its own release schedule once again. Sure, splitting out Keras doesn’t affect a developer’s day-to-day work, but such a high-profile reversal in a minor revision of the framework doesn’t inspire confidence.
That said, TensorFlow is a reliable framework and hosts a large ecosystem for deep learning. You can build apps and models on TensorFlow that work at any scale, and you’ll be in good company if you do. But TensorFlow might not be your first choice these days.
Should you use PyTorch?
PyTorch is no longer the upstart nipping at TensorFlow’s heels; today it is a major force in the deep learning world, perhaps primarily in research, but increasingly in production applications as well. And with eager mode now the default development approach in TensorFlow as well as PyTorch, the more Pythonic approach offered by PyTorch’s automatic differentiation (autograd) seems to have won the war against static graphs.
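For a sense of what that Pythonic, define-by-run style looks like, here is a tiny, hedged autograd sketch; the tensor shapes and the toy loss are purely illustrative.

import torch

# Tensors created with requires_grad=True are tracked by autograd as ordinary Python runs.
x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)

loss = ((w * x).sum() - 1.0) ** 2   # plain eager code; no graph-building API
loss.backward()                     # gradients computed by replaying the recorded operations

print(x.grad)  # d(loss)/dx
print(w.grad)  # d(loss)/dw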
Unlike TensorFlow, PyTorch hasn’t seen any major breaking changes in its core code since the Variable API was deprecated in version 0.4. (Previously, Variables were required to use autograd with tensors; now everything is a tensor.) But that doesn’t mean there haven’t been a few missteps here and there. For example, if you’ve used PyTorch to train across multiple GPUs, you’ve probably run into the differences between DataParallel and the newer DistributedDataParallel. You should almost always use DistributedDataParallel, yet DataParallel isn’t actually deprecated.
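For reference, here is a minimal, hedged sketch of the DistributedDataParallel setup, assuming a single machine with NVIDIA GPUs and the torchrun launcher (which sets RANK, LOCAL_RANK, and WORLD_SIZE for each worker); the model itself is a placeholder.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).to(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])  # one process per GPU, gradients all-reduced

# The older, single-process torch.nn.DataParallel(model) still works,
# but DistributedDataParallel is the recommended path for multi-GPU training.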
Although PyTorch has lagged behind TensorFlow and JAX in XLA/TPU support, the situation has improved considerably as of 2022. PyTorch now supports TPU VM access as well as the older-style TPU Node support, along with easy command-line deployment for running your code on CPUs, GPUs, or TPUs with no code changes. And if you don’t want to deal with the boilerplate code that PyTorch often makes you write, you can turn to higher-level add-ons like PyTorch Lightning, which lets you concentrate on your actual work rather than rewriting training loops. On the downside, while work continues on PyTorch Mobile, it remains far less mature than TensorFlow Lite.
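As an illustration of how much boilerplate that removes, here is a hedged PyTorch Lightning sketch; the module and layer sizes are illustrative, and the commented-out Trainer call assumes a dataloader you would supply yourself.

import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x.view(x.size(0), -1))
        return torch.nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Lightning owns the training loop, device placement, logging, and checkpointing:
# trainer = pl.Trainer(max_epochs=3, accelerator="auto")
# trainer.fit(LitClassifier(), train_dataloaders=train_loader)  # train_loader is assumed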
On the production side, PyTorch now has integrations with framework-agnostic platforms like Kubeflow, while the TorchServe project handles deployment details such as scaling, metrics, and batched inference, giving you all that MLOps goodness in a small package maintained by the PyTorch developers themselves. Does PyTorch scale? Meta has been running PyTorch in production for years, so anyone who tells you that PyTorch can’t handle large-scale workloads is lying to you. Still, there’s a case to be made that PyTorch may not be quite as friendly as JAX for the very, very large training runs that require banks upon banks of GPUs or TPUs.
Finally, there is the elephant in the room. PyTorch’s popularity over the past few years is almost certainly tied to the success of Hugging Face’s Transformers library. Yes, Transformers now supports TensorFlow and JAX as well, but it started as a PyTorch project and remains closely tied to the framework. Between the rise of the Transformer architecture, PyTorch’s flexibility for research, and the ability to pull in so many new models within days or even hours of publication via Hugging Face’s model hub, it’s easy to see why PyTorch is spreading everywhere these days.
Should you use JAX?
If you’re not sold on TensorFlow, Google might have something else for you. Sort of, anyway. JAX is a deep learning framework that is built, maintained, and used by Google, but it isn’t an official Google product. However, if you look at Google/DeepMind papers and releases over the past year or so, you can’t help but notice that much of Google’s research has moved to JAX. So JAX isn’t an “official” Google product, but it is what Google researchers are using to push the envelope.
What is JAX, exactly? A simple way to think of it is this: imagine a GPU/TPU-accelerated version of NumPy that can, with a wave of a magic wand, automatically vectorize a Python function and handle all the derivative calculations on those functions. Finally, it has a JIT (just-in-time) component that takes your code and optimizes it for the XLA compiler, resulting in significant performance improvements over TensorFlow and PyTorch. I’ve seen code run four or five times faster simply by reimplementing it in JAX, without any real optimization work taking place.
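The core transforms (grad, vmap, and jit) compose freely over ordinary NumPy-style code. Here is a small, hedged sketch; the toy model and array shapes are purely illustrative.

import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)          # plain NumPy-style code

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

grad_loss = jax.grad(loss)                                # derivative w.r.t. the first argument
batched_predict = jax.vmap(predict, in_axes=(None, 0))    # vectorize over a batch of inputs
fast_loss = jax.jit(loss)                                 # compile the whole function with XLA

w = jnp.ones((4, 2))
x = jnp.ones((8, 4))
y = jnp.zeros((8, 2))
print(grad_loss(w, x, y).shape)       # (4, 2), same shape as w
print(batched_predict(w, x).shape)    # (8, 2)
print(fast_loss(w, x, y))             # same value as loss(w, x, y), but XLA-compiled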
Because JAX operates at the NumPy level, JAX code is written at a much lower level than TensorFlow/Keras and, yes, even PyTorch. Happily, there’s a small but growing ecosystem of surrounding projects that fill in the gaps. Want neural network libraries? There’s Flax from Google and Haiku from DeepMind (also Google). There’s Optax for all your optimization needs, PIX for image processing, and more. Once you’re working with something like Flax, building neural networks becomes relatively easy to get a handle on. Just be aware that there are still a few rough edges. For example, veterans talk a lot about how JAX handles random numbers differently than many other frameworks.
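Both points show up in a short, hedged Flax sketch; the layer sizes and input shapes are illustrative, and the Flax Linen API is assumed.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(10)(x)

# JAX has no global random state: you create explicit PRNG keys and split them yourself.
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)

model = MLP()
params = model.init(init_key, jnp.ones((1, 32)))   # parameters live outside the module
logits = model.apply(params, jnp.ones((4, 32)))    # a pure function of params and inputs
print(logits.shape)  # (4, 10)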
Should you convert everything to JAX and jump on this cutting-edge technology? Well, maybe, if you’re deep into research involving large-scale models that require enormous resources to train. JAX’s advances in areas such as deterministic training, and in other situations that require thousands of TPU pods, are probably worth the switch on their own.
TensorFlow vs. PyTorch vs. JAX
What’s the takeaway, then? Which deep learning framework should you use? Unfortunately, I don’t think there is a definitive answer. It all depends on the type of problem you are working on, the scale at which you plan to deploy your models, and even the compute platforms you are targeting.
However, I don’t think it’s controversial to say that if you’re working in the text and image domains, and you’re doing small- to medium-scale research with an eye toward deploying those models in production, then PyTorch is probably your best bet right now. It simply hits the sweet spot in that space these days.
If, however, you need to squeeze every bit of performance out of low-compute devices, I’d point you to TensorFlow, with its rock-solid TensorFlow Lite package. And at the other end of the scale, if you’re working on training models with tens or hundreds of billions of parameters or more, and you’re training them primarily for research purposes, then it might be time to give JAX a whirl.
Copyright © 2022 IDG Communications, Inc.