PyTorch Performance Features and How They Interact

April 14, 2023
Machine Learning Productionization
Pytorch, Optimization

Intro

PyTorch in 2023 is a complex beast, with many great performance features hidden away. Simple top-N lists are weak content, so I’ve empirically tested the most important PyTorch tuning techniques and settings in all combinations. I’ve benchmarked inference across a handful of different model architectures and sizes, different versions of PyTorch and even different Docker containers.

The resulting high-dimensional cube of performance metrics has over 800 data points — each dot in the following scatter plots represents a unique experiment, a combination of features enabled or not.

Many interesting relationships are evident in the data. The scatter plots generally highlight a dominant effect with color, with more subtle effects remaining visible in the structure of the points.

The article is structured as a set of high-level recommendations but the headlines are really 101-level PyTorch content — the deeper value is in the secondary interactions, so you may just have to read the text.

Always use Mixed Precision

torch.autocast() is the modern, automatic way to run your model with mixed precision while keeping numerically sensitive operations in float32.

Reduced precision on modern GPU architectures increases throughput and reduces memory usage (which also increases effective memory bandwidth). Inside an autocast region, PyTorch applies per-op rules to decide which ops can run in reduced precision (float16 or bfloat16) and casts tensors as required.

Mixed precision is enabled via a simple context manager:

import torch
with torch.autocast("cuda"):  # eligible ops run in float16 inside this block
    output = model(input)
    loss = loss_fn(output, target)

  • Despite torch.autocast() being “automatic”, it is worth closely checking that output remains acceptable: I’ve seen several cases where language models and CLIP-guided GAN systems degrade unacceptably with autocast.

  • If training, use a GradScaler if your gradients could underflow in fp16.
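
For training, the GradScaler advice above might look like the following minimal sketch. The toy linear model, random data, learning rate and step count are placeholders, not anything from the benchmarks. It uses torch.cuda.amp.GradScaler (newer releases may suggest a device-generic torch.amp equivalent instead) and, as an assumption made so the loop also runs without a GPU, disables both autocast and loss scaling on CPU.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # loss scaling only matters for float16 on the GPU

# Toy regression problem standing in for a real model and dataset.
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # a pass-through when disabled

x = torch.randn(64, 16, device=device)
target = x.sum(dim=1, keepdim=True)

losses = []
for step in range(50):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device, enabled=use_amp):  # float16 on CUDA by default
        output = model(x)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # scale the loss so small fp16 gradients don't underflow
    scaler.step(optimizer)         # unscale gradients; skip the step if they contain inf/NaN
    scaler.update()                # adapt the scale factor for the next iteration
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

If a scaled gradient overflows, that step is skipped and the scale factor is reduced, so a few early iterations on the GPU may not update the weights at all.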

Use autocast('cuda', dtype=torch.bfloat16) if your model needs the greater dynamic range of bf16 rather than the greater precision of fp16 (bf16 requires an Ampere or newer GPU).
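
One cheap way to act on the “check your output” advice is to run the same batch with and without autocast and compare. The sketch below does this for a small stand-in model (an assumption; substitute your own model, a representative input, and whatever quality metric matters to you). It uses float16 on a GPU and falls back to bfloat16 CPU autocast so it also runs without one.

```python
import torch

# float16 autocast on a GPU; bfloat16 CPU autocast so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

torch.manual_seed(0)
# A small stand-in model; substitute your own model and a representative batch.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10)
).to(device).eval()
x = torch.randn(32, 64, device=device)

with torch.inference_mode():
    reference = model(x)  # full float32 baseline
    with torch.autocast(device, dtype=amp_dtype):
        output = model(x)  # the linear layers run in reduced precision

max_err = (output.float() - reference).abs().max().item()
print(output.dtype, f"max abs difference vs float32: {max_err:.4f}")
```

A raw maximum difference is only a first check; for generative models, comparing task-level outputs (generated text or images) is what actually catches the degradation described above.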

torch.autocast() has a marked throughput impact across model architectures:

In the scatter plot above, several points are clear:

  • Mixed precision is essential for good inference throughput on GPUs — no surprise here, the hardware that accelerates lower-precision compute has been around for many years and generations by now.
  • A lot of unexplained variance remains in the data, and this is largely explained in later sections as we dig into the other significant performance settings.

Use Channels-Last memory format for convolutional models

Switch from the default NCHW layout to channels-last NHWC memory format, improving data locality and unlocking optimized convolution kernels.

Here’s how to convert your model and input to channels-last format:

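import torch
model = model.to(memory_format=torch.channels_last)  # convert conv weights to NHWC
input = input.to(memory_format=torch.channels_last)  # convert the input batch to NHWC
with torch.autocast('cuda'):
    output = model(input)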

Note that this is only an internal memory format change — from an API perspective tensor shapes remain the same and tensor indexing is unchanged, including in your output tensor. This super convenient fact is due to PyTorch separating and hiding tensor storage details from client code. Feel free to continue pretending you are accessing contiguous memory, and the striding magic will work seamlessly (or throw an exception).

  • Channels-last works best for convolution-heavy models with float16 precision.
  • Be sure to match model (layer) memory format to input tensor memory format (both contiguous or both channels-last).
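
To see that only the storage order changes, the sketch below converts a random batch and compares shapes, strides and values. The tensor sizes and the single Conv2d layer are arbitrary choices for illustration; it runs on CPU, and the stride layout is the same on a GPU.

```python
import torch

x = torch.randn(8, 3, 32, 32)                   # an NCHW batch
x_cl = x.to(memory_format=torch.channels_last)  # same data, NHWC storage order

print(x.shape == x_cl.shape)  # True: the logical shape does not change
print(x.stride())             # (3072, 1024, 32, 1)
print(x_cl.stride())          # (3072, 1, 96, 3): channels are now the innermost dimension
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, x_cl))   # True: indexing returns identical values

# Convert the layer too, so model and input use the same memory format.
conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1).to(memory_format=torch.channels_last)
y = conv(x_cl)
print(y.shape)  # torch.Size([8, 16, 32, 32])
```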

After removing the effect of mixed precision (all compute is at float16), torch.channels_last is the next most important factor for convolution-heavy models:

Channels-last seems to improve throughput substantially for all models and conditions, but too much noise/structure remains in the data in this simple chart to be sure. A large part of this can be explained by breaking out the different docker containers used. The chart below shows that containers built by NVIDIA provide much higher throughput in combination with channels-last (these are accessible via NGC).

Use cuDNN Benchmarking for convolutional models

Runtime benchmarking for cuDNN convolutions will automatically select the fastest kernel implementations for your tensor dimensions and hardware.

cuDNN is a library of optimized kernels for deep neural networks created by NVIDIA. Turning on cuDNN benchmarking enables a just-in-time (JIT) profile-based optimization process to select the best kernel implementations from cuDNN for your specific layers, tensor sizes and hardware.

Enabling it is easy:

import torch
torch.backends.cudnn.benchmark = True
output = model(input) # benchmarking occurs JIT during first execution
output = model(input) # subsequent executions use the fastest available kernels

cuDNN benchmarking incurs a significant delay on the first run while kernels are profiled, and the profiling cost is paid again the first time each new set of tensor dimensions is seen.

There is no way to persist the profiled optimizations, so in production you’ll want to incur this delay during a warm-up phase.
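
A warm-up phase could be as simple as the sketch below, which runs one forward pass for every input shape you expect to serve. The warm_up helper, the toy convolutional model and the list of batch sizes are all hypothetical. Run the warm-up with the same precision and memory-format settings you use in production, since those also change which kernels get profiled. On a CPU-only machine the flag has no effect, but the code still runs.

```python
import torch

torch.backends.cudnn.benchmark = True  # only affects cuDNN convolutions on CUDA devices

def warm_up(model, input_shapes, device):
    """Hypothetical helper: one forward pass per expected input shape, so cuDNN
    profiles and caches its kernel choice for each shape before real traffic."""
    model.eval()
    with torch.inference_mode():
        for shape in input_shapes:
            model(torch.zeros(shape, device=device))
    if device == "cuda":
        torch.cuda.synchronize()  # wait until the queued GPU work has finished

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1), torch.nn.ReLU()
).to(device)
# Every batch size the (hypothetical) service expects to receive.
expected_shapes = [(batch, 3, 64, 64) for batch in (1, 8, 32)]
warm_up(model, expected_shapes, device)
```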

cuDNN benchmarking is easy to experiment with, so it is almost always worth trying if you can tolerate the container warm-up time.

With channels-last turned off and only NGC containers considered, cuDNN benchmarking still provides some boost to the convolutional models:

Focusing on the convolutional models and reintroducing channels-last, we can see that both techniques are independently valuable:

Use PyTorch 2.0 (or 2.1) if possible

Introduced in PyTorch 2.0, torch.compile can deliver substantial improvements in inference and training throughput.

torch.compile supersedes previous PyTorch model compilation efforts (e.g. TorchScript) and aims for both very high ease of use and excellent performance. Previous PyTorch-native techniques often needed significant code changes, and their performance was not comparable to more involved third-party approaches like TensorRT compilation, so torch.compile aims for the best of both worlds.

As of April 2023, torch.compile doesn’t yet fully succeed in either of these aims; the empirical evidence follows.

Using torch.compile could not be easier:

import torch
model = torch.compile(model)  # compilation itself happens lazily, on the first call
output = model(input)
loss = loss_fn(output, target)

With the above one-liner, torch.compile will just-in-time (JIT) compile arbitrary models and PyTorch code, splitting the computational graph as required into compilable and Python-native subgraphs. Compilation supports both the forward and backward pass, so unlike previous approaches it is suitable for both inference and training. It can also handle dynamic tensor shapes if desired.

The default TorchInductor backend compiles graphs into high performance Triton-based kernels for on-GPU execution. OpenAI’s Triton is an abstraction over NVIDIA’s CUDA, so in the future this may allow highly efficient PyTorch execution on a wider range of hardware.

Performance and compatibility are a mixed bag, and this is why I claim torch.compile doesn’t yet fully deliver on either the ease-of-use or performance aims. The chart below shows the effect of compilation, and is split vertically to separate PyTorch 2.0.0 and 2.1.0 (unreleased, nightly).

We can make several observations:

  • PyTorch version matters: 2.0.0 doesn’t compile the language models successfully. 2.1.0 increases coverage but doesn’t deliver any throughput gains for the language models.
  • The convolutional models show strong gains, but additional structure is clear in the data: torch.compile is not a replacement for using other parts of PyTorch correctly.
  • Channels-last and inference-mode are disabled here, since compilation fails with them enabled.
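
Since compilation can still fail for some models and settings, one defensive pattern is to attempt torch.compile and fall back to the eager model if the first call raises. The compile_with_fallback helper below is a hypothetical sketch, not a PyTorch API, and the small MLP is a stand-in. Catching every exception also hides genuine model bugs, so the error is printed rather than silently ignored.

```python
import torch

def compile_with_fallback(model, example_input, **compile_kwargs):
    """Hypothetical helper: return (model_to_use, compiled_ok).

    torch.compile is lazy, so most failures surface on the first call; that call
    is made here, and the eager model is returned if anything goes wrong."""
    if not hasattr(torch, "compile"):  # PyTorch < 2.0
        return model, False
    try:
        compiled = torch.compile(model, **compile_kwargs)
        compiled(example_input)  # triggers tracing and compilation
    except Exception as exc:  # broad on purpose: error types vary between versions
        print(f"torch.compile failed, falling back to eager mode: {type(exc).__name__}")
        return model, False
    return compiled, True

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4)
).eval()
x = torch.randn(8, 16)
runner, compiled_ok = compile_with_fallback(model, x)
print("compiled:", compiled_ok)
```

Whichever path is taken, it is worth benchmarking the result: as the chart shows, a successful compile does not guarantee a speedup.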

When considering only compiled data points, the continued importance of using eval mode correctly for inference becomes clear:

It’s worth digging a little into eval mode and a few of the other inference settings.

torch.inference_mode(), torch.no_grad() and model.eval() are not interchangeable

While model.eval(), torch.inference_mode(), and torch.no_grad() seem like they do similar things, they are all important and independently useful.

Use model.eval() to prepare modules and layers for inference. This has layer-specific effects, e.g. disabling dropout layers and making batch-norm layers use their running statistics. If you are doing inference, you definitely don’t want to forget this.

model = model.eval()
output = model(input)
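
A quick way to see what model.eval() changes is to run a module containing dropout twice on the same input in each mode. The tiny two-layer model below is purely illustrative.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))
x = torch.randn(4, 8)

model.train()  # the default mode for a freshly constructed module
print(torch.equal(model(x), model(x)))  # usually False: each call samples a new dropout mask

model.eval()   # dropout becomes a no-op (batch-norm would switch to running statistics)
print(torch.equal(model(x), model(x)))  # True: inference is now deterministic
print(model.training, model[1].training)  # False False: eval() recurses into submodules
```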

model.eval() improves inference throughput for all model architectures:

Use torch.inference_mode() to prevent gradients from being computed/stored during inference, allowing PyTorch to skip some tensor book-keeping and checks. This largely replaces torch.no_grad() as the way to tell PyTorch that the execution will not require gradient backprop.

with torch.inference_mode():
    output = model(input)

torch.inference_mode() improves throughput less consistently across architectures:

NB. The above model.eval() and torch.inference_mode() improvements are often cumulative — see the chart below.

Only consider torch.no_grad() if torch.inference_mode() doesn’t work. Tensors created under torch.inference_mode() cannot be saved for backward in later grad-mode computations, so if you want to skip gradient tracking for a section of code but need to use its tensor results in grad mode later, torch.no_grad() might be for you. See Autograd mechanics for more detail.
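
The difference can be seen directly: a tensor produced under torch.inference_mode() is flagged as an inference tensor and is rejected when autograd needs to save it, while a torch.no_grad() result is an ordinary tensor. A minimal sketch with arbitrary values:

```python
import torch

w = torch.randn(3, requires_grad=True)

with torch.inference_mode():
    a = torch.ones(3) * 2  # an "inference tensor"
with torch.no_grad():
    b = torch.ones(3) * 2  # an ordinary tensor that autograd simply did not track

print(a.is_inference(), b.is_inference())  # True False

# The no_grad result can feed a later autograd computation...
(w * b).sum().backward()
grad_from_b = w.grad.clone()
print(grad_from_b)  # tensor([2., 2., 2.])

# ...but the inference tensor cannot be saved for backward outside inference mode.
raised = False
try:
    (w * a).sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```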

Final thoughts

For me this article raised many more questions than it answered — the anomalies and hidden relationships highlighted above are just begging to be traced and understood more deeply.

These recommendations are not ground-breaking, they are arguably just “using PyTorch correctly,” but I hope I’ve made a few key points clearly and empirically:

  • Getting everything right is much better than getting everything wrong — obviously, but this final chart shows the massive throughput difference between baseline and the optimal configuration:
  • There is no all-encompassing “go fast” option for inference in PyTorch — especially considering different model architectures, you will need to iteratively experiment and measure to get best performance. This remains true in PyTorch 2.0 and beyond. To be truly effective in this process you also need to do some profiling (using a tool like NVIDIA’s Nsight Systems) — measurements can tell you when performance has changed but can’t tell you what to try next.

  • System/architecture issues are important too — for example, how your containers or core libraries were built can impact performance as much as correct code. Use the latest PyTorch, CUDA, cuDNN and drivers if you can.
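
As a starting point for that kind of iterative measurement, the sketch below times a toy convolutional model under every combination of two of the settings discussed above (channels-last and eval mode), loosely mirroring this article's all-combinations approach at a much smaller scale. The measure_throughput helper, the model and the iteration counts are hypothetical. The explicit torch.cuda.synchronize() calls matter: CUDA work is launched asynchronously, so wall-clock timings without them are misleading. Numbers like these tell you that something changed, not why, so they complement rather than replace a profiler.

```python
import itertools
import time

import torch

def measure_throughput(model, x, iters=20, warmup=5):
    """Hypothetical helper: forward passes per second for a fixed input."""
    with torch.inference_mode():
        for _ in range(warmup):  # absorb one-off costs such as cuDNN benchmarking
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # CUDA kernels run asynchronously
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return iters / elapsed

device = "cuda" if torch.cuda.is_available() else "cpu"
results = {}
for channels_last, eval_mode in itertools.product([False, True], repeat=2):
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 32, 3, padding=1), torch.nn.ReLU(), torch.nn.Dropout(0.1),
        torch.nn.Conv2d(32, 32, 3, padding=1),
    ).to(device)
    x = torch.randn(16, 3, 64, 64, device=device)
    if channels_last:  # convert model and input together
        model = model.to(memory_format=torch.channels_last)
        x = x.to(memory_format=torch.channels_last)
    if eval_mode:
        model.eval()
    results[(channels_last, eval_mode)] = measure_throughput(model, x)

for (channels_last, eval_mode), rate in results.items():
    print(f"channels_last={channels_last!s:<5} eval={eval_mode!s:<5} {rate:8.1f} it/s")
```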

Thanks for reading, and I hope you found this useful!

© Paul Bridger 2020